|
19 | 19 | from diffusers import HunyuanVideoTransformer3DModel |
20 | 20 | from diffusers.utils.testing_utils import ( |
21 | 21 | enable_full_determinism, |
22 | | - is_torch_compile, |
23 | | - require_torch_2, |
24 | | - require_torch_gpu, |
25 | | - slow, |
26 | 22 | torch_device, |
27 | 23 | ) |
28 | 24 |
|
29 | | -from ..test_modeling_common import ModelTesterMixin |
| 25 | +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin |
30 | 26 |
|
31 | 27 |
|
32 | 28 | enable_full_determinism() |
33 | 29 |
|
34 | 30 |
|
35 | | -class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): |
| 31 | +class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): |
36 | 32 | model_class = HunyuanVideoTransformer3DModel |
37 | 33 | main_input_name = "hidden_states" |
38 | 34 | uses_custom_attn_processor = True |
@@ -96,23 +92,8 @@ def test_gradient_checkpointing_is_applied(self): |
96 | 92 | expected_set = {"HunyuanVideoTransformer3DModel"} |
97 | 93 | super().test_gradient_checkpointing_is_applied(expected_set=expected_set) |
98 | 94 |
|
99 | | - @require_torch_gpu |
100 | | - @require_torch_2 |
101 | | - @is_torch_compile |
102 | | - @slow |
103 | | - def test_torch_compile_recompilation_and_graph_break(self): |
104 | | - torch._dynamo.reset() |
105 | | - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
106 | 95 |
|
107 | | - model = self.model_class(**init_dict).to(torch_device) |
108 | | - model = torch.compile(model, fullgraph=True) |
109 | | - |
110 | | - with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): |
111 | | - _ = model(**inputs_dict) |
112 | | - _ = model(**inputs_dict) |
113 | | - |
114 | | - |
115 | | -class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): |
| 96 | +class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): |
116 | 97 | model_class = HunyuanVideoTransformer3DModel |
117 | 98 | main_input_name = "hidden_states" |
118 | 99 | uses_custom_attn_processor = True |
@@ -179,23 +160,8 @@ def test_gradient_checkpointing_is_applied(self): |
179 | 160 | expected_set = {"HunyuanVideoTransformer3DModel"} |
180 | 161 | super().test_gradient_checkpointing_is_applied(expected_set=expected_set) |
181 | 162 |
|
182 | | - @require_torch_gpu |
183 | | - @require_torch_2 |
184 | | - @is_torch_compile |
185 | | - @slow |
186 | | - def test_torch_compile_recompilation_and_graph_break(self): |
187 | | - torch._dynamo.reset() |
188 | | - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
189 | | - |
190 | | - model = self.model_class(**init_dict).to(torch_device) |
191 | | - model = torch.compile(model, fullgraph=True) |
192 | | - |
193 | | - with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): |
194 | | - _ = model(**inputs_dict) |
195 | | - _ = model(**inputs_dict) |
196 | | - |
197 | 163 |
|
198 | | -class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): |
| 164 | +class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): |
199 | 165 | model_class = HunyuanVideoTransformer3DModel |
200 | 166 | main_input_name = "hidden_states" |
201 | 167 | uses_custom_attn_processor = True |
@@ -260,23 +226,10 @@ def test_gradient_checkpointing_is_applied(self): |
260 | 226 | expected_set = {"HunyuanVideoTransformer3DModel"} |
261 | 227 | super().test_gradient_checkpointing_is_applied(expected_set=expected_set) |
262 | 228 |
|
263 | | - @require_torch_gpu |
264 | | - @require_torch_2 |
265 | | - @is_torch_compile |
266 | | - @slow |
267 | | - def test_torch_compile_recompilation_and_graph_break(self): |
268 | | - torch._dynamo.reset() |
269 | | - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
270 | 229 |
|
271 | | - model = self.model_class(**init_dict).to(torch_device) |
272 | | - model = torch.compile(model, fullgraph=True) |
273 | | - |
274 | | - with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): |
275 | | - _ = model(**inputs_dict) |
276 | | - _ = model(**inputs_dict) |
277 | | - |
278 | | - |
279 | | -class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): |
| 230 | +class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests( |
| 231 | + ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase |
| 232 | +): |
280 | 233 | model_class = HunyuanVideoTransformer3DModel |
281 | 234 | main_input_name = "hidden_states" |
282 | 235 | uses_custom_attn_processor = True |
@@ -342,18 +295,3 @@ def test_output(self): |
342 | 295 | def test_gradient_checkpointing_is_applied(self): |
343 | 296 | expected_set = {"HunyuanVideoTransformer3DModel"} |
344 | 297 | super().test_gradient_checkpointing_is_applied(expected_set=expected_set) |
345 | | - |
346 | | - @require_torch_gpu |
347 | | - @require_torch_2 |
348 | | - @is_torch_compile |
349 | | - @slow |
350 | | - def test_torch_compile_recompilation_and_graph_break(self): |
351 | | - torch._dynamo.reset() |
352 | | - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
353 | | - |
354 | | - model = self.model_class(**init_dict).to(torch_device) |
355 | | - model = torch.compile(model, fullgraph=True) |
356 | | - |
357 | | - with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): |
358 | | - _ = model(**inputs_dict) |
359 | | - _ = model(**inputs_dict) |
0 commit comments