|
19 | 19 | from diffusers import HunyuanVideoTransformer3DModel |
20 | 20 | from diffusers.utils.testing_utils import ( |
21 | 21 | enable_full_determinism, |
| 22 | + is_torch_compile, |
| 23 | + require_torch_2, |
| 24 | + require_torch_gpu, |
| 25 | + slow, |
22 | 26 | torch_device, |
23 | 27 | ) |
24 | 28 |
|
25 | | -from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin |
| 29 | +from ..test_modeling_common import ModelTesterMixin |
26 | 30 |
|
27 | 31 |
|
28 | 32 | enable_full_determinism() |
29 | 33 |
|
30 | 34 |
|
31 | | -class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): |
| 35 | +class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): |
32 | 36 | model_class = HunyuanVideoTransformer3DModel |
33 | 37 | main_input_name = "hidden_states" |
34 | 38 | uses_custom_attn_processor = True |
@@ -92,8 +96,23 @@ def test_gradient_checkpointing_is_applied(self): |
92 | 96 | expected_set = {"HunyuanVideoTransformer3DModel"} |
93 | 97 | super().test_gradient_checkpointing_is_applied(expected_set=expected_set) |
94 | 98 |
|
| 99 | + @require_torch_gpu |
| 100 | + @require_torch_2 |
| 101 | + @is_torch_compile |
| 102 | + @slow |
| 103 | + def test_torch_compile_recompilation_and_graph_break(self): |
| 104 | + torch._dynamo.reset() |
| 105 | + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
95 | 106 |
|
96 | | -class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): |
| 107 | + model = self.model_class(**init_dict).to(torch_device) |
| 108 | + model = torch.compile(model, fullgraph=True) |
| 109 | + |
| 110 | + with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): |
| 111 | + _ = model(**inputs_dict) |
| 112 | + _ = model(**inputs_dict) |
| 113 | + |
| 114 | + |
| 115 | +class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): |
97 | 116 | model_class = HunyuanVideoTransformer3DModel |
98 | 117 | main_input_name = "hidden_states" |
99 | 118 | uses_custom_attn_processor = True |
@@ -160,8 +179,23 @@ def test_gradient_checkpointing_is_applied(self): |
160 | 179 | expected_set = {"HunyuanVideoTransformer3DModel"} |
161 | 180 | super().test_gradient_checkpointing_is_applied(expected_set=expected_set) |
162 | 181 |
|
| 182 | + @require_torch_gpu |
| 183 | + @require_torch_2 |
| 184 | + @is_torch_compile |
| 185 | + @slow |
| 186 | + def test_torch_compile_recompilation_and_graph_break(self): |
| 187 | + torch._dynamo.reset() |
| 188 | + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
| 189 | + |
| 190 | + model = self.model_class(**init_dict).to(torch_device) |
| 191 | + model = torch.compile(model, fullgraph=True) |
| 192 | + |
| 193 | + with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): |
| 194 | + _ = model(**inputs_dict) |
| 195 | + _ = model(**inputs_dict) |
| 196 | + |
163 | 197 |
|
164 | | -class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): |
| 198 | +class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): |
165 | 199 | model_class = HunyuanVideoTransformer3DModel |
166 | 200 | main_input_name = "hidden_states" |
167 | 201 | uses_custom_attn_processor = True |
@@ -226,10 +260,23 @@ def test_gradient_checkpointing_is_applied(self): |
226 | 260 | expected_set = {"HunyuanVideoTransformer3DModel"} |
227 | 261 | super().test_gradient_checkpointing_is_applied(expected_set=expected_set) |
228 | 262 |
|
| 263 | + @require_torch_gpu |
| 264 | + @require_torch_2 |
| 265 | + @is_torch_compile |
| 266 | + @slow |
| 267 | + def test_torch_compile_recompilation_and_graph_break(self): |
| 268 | + torch._dynamo.reset() |
| 269 | + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
229 | 270 |
|
230 | | -class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests( |
231 | | - ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase |
232 | | -): |
| 271 | + model = self.model_class(**init_dict).to(torch_device) |
| 272 | + model = torch.compile(model, fullgraph=True) |
| 273 | + |
| 274 | + with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): |
| 275 | + _ = model(**inputs_dict) |
| 276 | + _ = model(**inputs_dict) |
| 277 | + |
| 278 | + |
| 279 | +class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): |
233 | 280 | model_class = HunyuanVideoTransformer3DModel |
234 | 281 | main_input_name = "hidden_states" |
235 | 282 | uses_custom_attn_processor = True |
@@ -295,3 +342,18 @@ def test_output(self): |
295 | 342 | def test_gradient_checkpointing_is_applied(self): |
296 | 343 | expected_set = {"HunyuanVideoTransformer3DModel"} |
297 | 344 | super().test_gradient_checkpointing_is_applied(expected_set=expected_set) |
| 345 | + |
| 346 | + @require_torch_gpu |
| 347 | + @require_torch_2 |
| 348 | + @is_torch_compile |
| 349 | + @slow |
| 350 | + def test_torch_compile_recompilation_and_graph_break(self): |
| 351 | + torch._dynamo.reset() |
| 352 | + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
| 353 | + |
| 354 | + model = self.model_class(**init_dict).to(torch_device) |
| 355 | + model = torch.compile(model, fullgraph=True) |
| 356 | + |
| 357 | + with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): |
| 358 | + _ = model(**inputs_dict) |
| 359 | + _ = model(**inputs_dict) |
0 commit comments