
Commit f742b32

StrongerXi authored and pytorchmergebot committed
[dynamo] Avoid recompiling over unused objects (pytorch#156891)
Dynamo was aggressively specializing on lazy VTs: `set_name_hint` in `STORE_FAST` (and similar instructions) and `isinstance` in `LOAD_FAST_CHECK` both forced realization. This caused regional `torch.compile`, when used to optimize ComfyUI GGUF + LoRA, to either (1) exceed the recompilation limit of 8, resulting in suboptimal performance, or (2) when the recompilation limit was raised, spend unnecessarily long compiling (180s vs. 20s for Flux). This patch fixes the recompilation issue.

Pull Request resolved: pytorch#156891
Approved by: https://github.com/williamwen42, https://github.com/mlazos
1 parent 317520b commit f742b32
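
A minimal sketch of the failure mode described in the commit message (illustrative only, not the ComfyUI code; `apply` and its arguments are made up for this example). Dynamo's recompile logging makes the pre-patch guard churn visible:

import torch

# Each call binds fresh, never-used objects to locals inside the compiled frame.
@torch.compile(backend="eager")
def apply(x, items):
    for key, patch in items:  # `key`/`patch` are stored but never used
        pass
    return x * x + 1

torch._logging.set_logs(recompiles=True)  # report the guard failures behind recompiles
x = torch.rand(8)
apply(x, [("a", 1)])
apply(x, [("b", 2)])  # before this patch, specializing on `key`/`patch` could recompile here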

12 files changed (+76, -21 lines)

benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv

Lines changed: 2 additions & 2 deletions
@@ -210,7 +210,7 @@ mobilenet_v2,pass,0
-mobilenet_v2_quantized_qat,pass,2
+mobilenet_v2_quantized_qat,pass,3
@@ -274,7 +274,7 @@ resnet50,pass,0
-resnet50_quantized_qat,pass,2
+resnet50_quantized_qat,pass,3

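(In these expected-results CSVs the columns are model name, accuracy status, and expected graph-break count; the 2 -> 3 bumps for the quantized-QAT models here and in the sibling CSVs appear to track a behavior change introduced by this patch, as recorded by CI.)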
benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv

Lines changed: 2 additions & 2 deletions
@@ -210,7 +210,7 @@ mobilenet_v2,pass,0
-mobilenet_v2_quantized_qat,pass,2
+mobilenet_v2_quantized_qat,pass,3
@@ -274,7 +274,7 @@ resnet50,pass,0
-resnet50_quantized_qat,pass,2
+resnet50_quantized_qat,pass,3

benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv

Lines changed: 2 additions & 2 deletions
@@ -210,7 +210,7 @@ mobilenet_v2,pass,0
-mobilenet_v2_quantized_qat,pass,2
+mobilenet_v2_quantized_qat,pass,3
@@ -274,7 +274,7 @@ resnet50,pass,0
-resnet50_quantized_qat,pass,2
+resnet50_quantized_qat,pass,3

benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv

Lines changed: 2 additions & 2 deletions
@@ -194,7 +194,7 @@ mobilenet_v2,pass,0
-mobilenet_v2_quantized_qat,pass,2
+mobilenet_v2_quantized_qat,pass,3
@@ -258,7 +258,7 @@ resnet50,pass,0
-resnet50_quantized_qat,pass,2
+resnet50_quantized_qat,pass,3

benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv

Lines changed: 2 additions & 2 deletions
@@ -210,7 +210,7 @@ mobilenet_v2,pass,0
-mobilenet_v2_quantized_qat,pass,2
+mobilenet_v2_quantized_qat,pass,3
@@ -274,7 +274,7 @@ resnet50,pass,0
-resnet50_quantized_qat,pass,2
+resnet50_quantized_qat,pass,3

test/dynamo/test_higher_order_ops.py

Lines changed: 6 additions & 8 deletions
@@ -4483,15 +4483,13 @@ def wrapper_fn(model, params, buffers, inputs):
         if torch._dynamo.config.inline_inbuilt_nn_modules:
             expected = """\
 class GraphModule(torch.nn.Module):
-    def forward(self, L_params_l1_weight_: "f32[1, 1]", L_params_l1_bias_: "f32[1]", L_buffers_buffer_: "f32[1]", L_inputs_: "f32[1, 1]"):
-        l_params_l1_weight_ = L_params_l1_weight_
-        l_params_l1_bias_ = L_params_l1_bias_
-        l_buffers_buffer_ = L_buffers_buffer_
+    def forward(self, L_inputs_: "f32[1, 1]", L_model_modules_l1_parameters_weight_: "f32[1, 1]", L_model_modules_l1_parameters_bias_: "f32[1]", L_model_buffers_buffer_: "f32[1]"):
         l_inputs_ = L_inputs_
-
-        linear: "f32[1, 1]" = torch._C._nn.linear(l_inputs_, l_params_l1_weight_, l_params_l1_bias_); l_inputs_ = l_params_l1_weight_ = l_params_l1_bias_ = None
-
-        add: "f32[1, 1]" = linear + l_buffers_buffer_; linear = l_buffers_buffer_ = None
+        l_model_modules_l1_parameters_weight_ = L_model_modules_l1_parameters_weight_
+        l_model_modules_l1_parameters_bias_ = L_model_modules_l1_parameters_bias_
+        l_model_buffers_buffer_ = L_model_buffers_buffer_
+        linear: "f32[1, 1]" = torch._C._nn.linear(l_inputs_, l_model_modules_l1_parameters_weight_, l_model_modules_l1_parameters_bias_); l_inputs_ = l_model_modules_l1_parameters_weight_ = l_model_modules_l1_parameters_bias_ = None
+        add: "f32[1, 1]" = linear + l_model_buffers_buffer_; linear = l_model_buffers_buffer_ = None
         return (add,)
 """
             # We found Windows/Linux have some empty line difference, empty_line_normalizer will help fix it.

test/dynamo/test_misc.py

Lines changed: 1 addition & 1 deletion
@@ -6823,7 +6823,7 @@ def fn(x):
             # assign fstring to a variable causes the fstring to be used,
             # which realizes the variable tracker.
             f_str = f"{x.shape[0]}"
-            return x.sin()
+            return x.sin(), f_str
 
         guard_failure = None

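A plausible reading of this test change: after this patch, STORE_FAST no longer realizes a lazy VariableTracker just to set a name hint, so merely assigning the f-string to f_str would no longer exercise the realization path the comment describes; returning f_str makes the value genuinely used. A standalone sketch of the pattern (names are illustrative):

import torch

@torch.compile(backend="eager")
def fn(x):
    f_str = f"{x.shape[0]}"  # assignment alone no longer forces realization
    return x.sin(), f_str    # returning the string does force it

fn(torch.rand(3))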
test/dynamo/test_recompiles.py

Lines changed: 23 additions & 0 deletions
@@ -499,6 +499,29 @@ def f(x, foo):
         f(x, foo1)
         self.assertEqual(counter.frame_count, 2)
 
+    def test_no_recompile_over_unused_objects(self):
+        # This is a regression test case that imitates
+        # https://github.com/city96/ComfyUI-GGUF/blob/47bec6147569a138dd30ad3e14f190a36a3be456/ops.py#L169-L182
+        counter = torch._dynamo.testing.CompileCounter()
+
+        def f(x, key, patches):
+            return x * x + 1
+
+        @torch.compile(backend=counter, fullgraph=True)
+        def apply_patches(f, x, keys):
+            patches = []
+            for key, patch in keys:  # noqa: F402
+                patches.append(patch)
+                x = f(x, key, patches)
+            return x
+
+        # no recompilation
+        x = torch.rand(10)
+        apply_patches(f, x, [("a", 1), ("b", 2)])
+        self.assertEqual(counter.frame_count, 1)
+        apply_patches(f, x, [("c", 3), ("d", 4)])
+        self.assertEqual(counter.frame_count, 1)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

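The new test drives compilation through torch._dynamo.testing.CompileCounter, which counts compiled frames via its frame_count attribute. A self-contained sketch of the same no-recompile check (the function and inputs are illustrative, not from the test suite):

import torch
import torch._dynamo.testing

counter = torch._dynamo.testing.CompileCounter()

@torch.compile(backend=counter, fullgraph=True)
def g(x, tag):
    return x + 1  # `tag` is accepted but never used

g(torch.rand(4), "a")
g(torch.rand(4), "b")
assert counter.frame_count == 1  # the unused `tag` must not cause a recompile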
test/torch_np/numpy_tests/core/test_multiarray.py

Lines changed: 1 addition & 0 deletions
@@ -3779,6 +3779,7 @@ def test_datetime(self):
         expected_idx = np.array([2, 1, 0])
         assert_array_equal(idx, expected_idx)
 
+    @xfail  # GH issue #157720
     def test_object(self):  # gh-6312
         a = np.random.choice(10, 1000)
         b = np.random.choice(["abc", "xy", "wz", "efghi", "qwst", "x"], 1000)

torch/_dynamo/symbolic_convert.py

Lines changed: 1 addition & 1 deletion
@@ -3035,7 +3035,7 @@ def END_FOR(self, inst):
         self.popn(2)
 
     def LOAD_FAST_CHECK(self, inst):
-        if isinstance(self.symbolic_locals.get(inst.argval, None), NullVariable):
+        if istype(self.symbolic_locals.get(inst.argval, None), NullVariable):
             unimplemented_v2(
                 gb_type="LOAD_FAST_CHECK on uninitialized variable",
                 context=inst.argval,

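The one-line fix swaps isinstance for Dynamo's istype. isinstance on a symbolic local can realize a LazyVariableTracker (the VariableTracker metaclass hooks __instancecheck__), installing guards and inviting recompiles; istype compares exact types and leaves lazy trackers alone. Roughly, simplified from torch/_dynamo/utils.py:

def istype(obj, allowed_types):
    # Exact type match: never consults __instancecheck__, so it cannot
    # trigger realization of a LazyVariableTracker the way isinstance can.
    if isinstance(allowed_types, (set, tuple, list)):
        return type(obj) in allowed_types
    return type(obj) is allowed_types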