
Commit e9fee7c
better manage compilation cache.
1 parent a17d537

File tree (5 files changed: +24 −101 lines changed):
  tests/models/test_modeling_common.py
  tests/models/transformers/test_models_transformer_hunyuan_video.py
  tests/models/transformers/test_models_transformer_wan.py
  tests/pipelines/test_pipelines.py
  tests/pipelines/test_pipelines_common.py
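Two changes recur across the files below: the private torch._dynamo.reset() is replaced by the public torch.compiler.reset(), and compile-sensitive tests are scoped to a fresh inductor cache. A minimal sketch of the resulting pattern, assuming PyTorch 2.x with inductor available (fresh_inductor_cache is a private utility, so its location may change between releases):

import torch
from torch._inductor.utils import fresh_inductor_cache

# torch.compiler.reset() is the public replacement for torch._dynamo.reset():
# it clears all torch.compile state accumulated in the current process.
torch.compiler.reset()

model = torch.compile(torch.nn.Linear(4, 4), fullgraph=True)

# fresh_inductor_cache() redirects inductor's cache to a temporary directory,
# so the compile below starts cold instead of reusing earlier artifacts;
# error_on_recompile=True turns any unexpected recompilation into a hard error.
with (
    fresh_inductor_cache(),
    torch._dynamo.config.patch(error_on_recompile=True),
    torch.no_grad(),
):
    x = torch.randn(2, 4)
    _ = model(x)  # first call: compiles
    _ = model(x)  # second call: must reuse the compiled graph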

tests/models/test_modeling_common.py
Lines changed: 12 additions & 8 deletions

@@ -1721,14 +1721,14 @@ class TorchCompileTesterMixin:
     def setUp(self):
         # clean up the VRAM before each test
         super().setUp()
-        torch._dynamo.reset()
+        torch.compiler.reset()
         gc.collect()
         backend_empty_cache(torch_device)
 
     def tearDown(self):
         # clean up the VRAM after each test in case of CUDA runtime errors
         super().tearDown()
-        torch._dynamo.reset()
+        torch.compiler.reset()
         gc.collect()
         backend_empty_cache(torch_device)
 
@@ -1737,13 +1737,17 @@ def tearDown(self):
     @is_torch_compile
     @slow
     def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
+        torch.compiler.reset()
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
         model = self.model_class(**init_dict).to(torch_device)
         model = torch.compile(model, fullgraph=True)
 
-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
+        with (
+            torch._inductor.utils.fresh_inductor_cache(),
+            torch._dynamo.config.patch(error_on_recompile=True),
+            torch.no_grad(),
+        ):
             _ = model(**inputs_dict)
             _ = model(**inputs_dict)
 
@@ -1771,7 +1775,7 @@ def tearDown(self):
         # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
         # there will be recompilation errors, as torch caches the model when run in the same process.
         super().tearDown()
-        torch._dynamo.reset()
+        torch.compiler.reset()
         gc.collect()
         backend_empty_cache(torch_device)
 
@@ -1905,21 +1909,21 @@ def test_hotswapping_model(self, rank0, rank1):
     def test_hotswapping_compiled_model_linear(self, rank0, rank1):
         # It's important to add this context to raise an error on recompilation
         target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
-        with torch._dynamo.config.patch(error_on_recompile=True):
+        with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
         # It's important to add this context to raise an error on recompilation
         target_modules = ["conv", "conv1", "conv2"]
-        with torch._dynamo.config.patch(error_on_recompile=True):
+        with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
 
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
         # It's important to add this context to raise an error on recompilation
         target_modules = ["to_q", "conv"]
-        with torch._dynamo.config.patch(error_on_recompile=True):
+        with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
             self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
 
     def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
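The error_on_recompile guard is what gives the fullgraph test its teeth: any second compilation, for instance after an input-shape change, raises instead of silently recompiling. A small sketch (not from this commit) of the failure mode the guard surfaces:

import torch

torch.compiler.reset()
model = torch.compile(torch.nn.Linear(4, 4), fullgraph=True)

with torch._dynamo.config.patch(error_on_recompile=True):
    _ = model(torch.randn(2, 4))  # compiles for batch size 2
    try:
        _ = model(torch.randn(8, 4))  # new batch size forces a recompile
    except Exception as exc:  # torch._dynamo raises a RecompileError here
        print(f"recompilation detected: {exc}")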

tests/models/transformers/test_models_transformer_hunyuan_video.py
Lines changed: 7 additions & 69 deletions

@@ -19,20 +19,16 @@
 from diffusers import HunyuanVideoTransformer3DModel
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
-    is_torch_compile,
-    require_torch_2,
-    require_torch_gpu,
-    slow,
     torch_device,
 )
 
-from ..test_modeling_common import ModelTesterMixin
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
 
 
 enable_full_determinism()
 
 
-class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
     model_class = HunyuanVideoTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
@@ -96,23 +92,8 @@ def test_gradient_checkpointing_is_applied(self):
         expected_set = {"HunyuanVideoTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
-    def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
-        model = self.model_class(**init_dict).to(torch_device)
-        model = torch.compile(model, fullgraph=True)
-
-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
-            _ = model(**inputs_dict)
-            _ = model(**inputs_dict)
-
-
-class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
     model_class = HunyuanVideoTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
@@ -179,23 +160,8 @@ def test_gradient_checkpointing_is_applied(self):
         expected_set = {"HunyuanVideoTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
-    def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict).to(torch_device)
-        model = torch.compile(model, fullgraph=True)
-
-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
-            _ = model(**inputs_dict)
-            _ = model(**inputs_dict)
-
 
-class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
     model_class = HunyuanVideoTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
@@ -260,23 +226,10 @@ def test_gradient_checkpointing_is_applied(self):
         expected_set = {"HunyuanVideoTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
-    def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
-        model = self.model_class(**init_dict).to(torch_device)
-        model = torch.compile(model, fullgraph=True)
-
-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
-            _ = model(**inputs_dict)
-            _ = model(**inputs_dict)
-
-
-class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(
+    ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase
+):
     model_class = HunyuanVideoTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
@@ -342,18 +295,3 @@ def test_output(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"HunyuanVideoTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
-    def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict).to(torch_device)
-        model = torch.compile(model, fullgraph=True)
-
-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
-            _ = model(**inputs_dict)
-            _ = model(**inputs_dict)
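Note the mixin's position in each base-class list: because Python resolves setUp/tearDown left to right along the MRO and the mixins call super() cooperatively, TorchCompileTesterMixin must sit before unittest.TestCase for its reset logic to run. A toy illustration with hypothetical class names:

import unittest

import torch


class CompileResetMixin:
    def setUp(self):
        super().setUp()  # cooperative: forwards along the MRO to TestCase
        torch.compiler.reset()  # runs before every test in the subclass


class ToyModelTests(CompileResetMixin, unittest.TestCase):
    # MRO: ToyModelTests -> CompileResetMixin -> unittest.TestCase -> object
    def test_something(self):
        self.assertTrue(True)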

tests/models/transformers/test_models_transformer_wan.py
Lines changed: 2 additions & 21 deletions

@@ -19,20 +19,16 @@
 from diffusers import WanTransformer3DModel
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
-    is_torch_compile,
-    require_torch_2,
-    require_torch_gpu,
-    slow,
     torch_device,
 )
 
-from ..test_modeling_common import ModelTesterMixin
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
 
 
 enable_full_determinism()
 
 
-class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+class WanTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
     model_class = WanTransformer3DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
@@ -86,18 +82,3 @@ def prepare_init_args_and_inputs_for_common(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"WanTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
-    def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict).to(torch_device)
-        model = torch.compile(model, fullgraph=True)
-
-        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
-            _ = model(**inputs_dict)
-            _ = model(**inputs_dict)
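The inherited test only needs each class to provide model_class and prepare_init_args_and_inputs_for_common(). A self-contained toy showing that contract (everything below is illustrative; the real classes live in the diffusers test suite):

import unittest

import torch


class ToyModel(torch.nn.Module):
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_dim, hidden_dim)

    def forward(self, hidden_states):
        return self.proj(hidden_states)


class ToyCompileTests(unittest.TestCase):
    model_class = ToyModel

    def prepare_init_args_and_inputs_for_common(self):
        # The mixin does: model = self.model_class(**init_dict); model(**inputs_dict)
        init_dict = {"hidden_dim": 8}
        inputs_dict = {"hidden_states": torch.randn(2, 8)}
        return init_dict, inputs_dict

    def test_compiles_without_recompilation(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = torch.compile(self.model_class(**init_dict), fullgraph=True)
        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)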

tests/pipelines/test_pipelines.py
Lines changed: 1 addition & 1 deletion

@@ -2209,7 +2209,7 @@ def tearDown(self):
         # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
         # there will be recompilation errors, as torch caches the model when run in the same process.
         super().tearDown()
-        torch._dynamo.reset()
+        torch.compiler.reset()
         gc.collect()
         backend_empty_cache(torch_device)

tests/pipelines/test_pipelines_common.py
Lines changed: 2 additions & 2 deletions

@@ -1111,14 +1111,14 @@ def callback_cfg_params(self) -> frozenset:
     def setUp(self):
         # clean up the VRAM before each test
         super().setUp()
-        torch._dynamo.reset()
+        torch.compiler.reset()
         gc.collect()
         backend_empty_cache(torch_device)
 
     def tearDown(self):
         # clean up the VRAM after each test in case of CUDA runtime errors
         super().tearDown()
-        torch._dynamo.reset()
+        torch.compiler.reset()
         gc.collect()
         backend_empty_cache(torch_device)
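The same setUp/tearDown hygiene now appears in every touched suite. A condensed sketch of the shared pattern, with torch.cuda.empty_cache() standing in for the repo's device-agnostic backend_empty_cache(torch_device) helper:

import gc
import unittest

import torch


class CompileCacheHygiene(unittest.TestCase):
    def setUp(self):
        super().setUp()
        torch.compiler.reset()  # drop torch.compile state left over from other tests
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def tearDown(self):
        super().tearDown()
        torch.compiler.reset()  # in case the test itself compiled anything
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()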
