|  | 
| 19 | 19 | from diffusers import HunyuanVideoTransformer3DModel | 
| 20 | 20 | from diffusers.utils.testing_utils import ( | 
| 21 | 21 |     enable_full_determinism, | 
| 22 |  | -    is_torch_compile, | 
| 23 |  | -    require_torch_2, | 
| 24 |  | -    require_torch_gpu, | 
| 25 |  | -    slow, | 
| 26 | 22 |     torch_device, | 
| 27 | 23 | ) | 
| 28 | 24 | 
 | 
| 29 |  | -from ..test_modeling_common import ModelTesterMixin | 
|  | 25 | +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin | 
| 30 | 26 | 
 | 
| 31 | 27 | 
 | 
| 32 | 28 | enable_full_determinism() | 
| 33 | 29 | 
 | 
| 34 | 30 | 
 | 
| 35 |  | -class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): | 
|  | 31 | +class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): | 
| 36 | 32 |     model_class = HunyuanVideoTransformer3DModel | 
| 37 | 33 |     main_input_name = "hidden_states" | 
| 38 | 34 |     uses_custom_attn_processor = True | 
| @@ -96,23 +92,8 @@ def test_gradient_checkpointing_is_applied(self): | 
| 96 | 92 |         expected_set = {"HunyuanVideoTransformer3DModel"} | 
| 97 | 93 |         super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | 
| 98 | 94 | 
 | 
| 99 |  | -    @require_torch_gpu | 
| 100 |  | -    @require_torch_2 | 
| 101 |  | -    @is_torch_compile | 
| 102 |  | -    @slow | 
| 103 |  | -    def test_torch_compile_recompilation_and_graph_break(self): | 
| 104 |  | -        torch._dynamo.reset() | 
| 105 |  | -        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() | 
| 106 | 95 | 
 | 
| 107 |  | -        model = self.model_class(**init_dict).to(torch_device) | 
| 108 |  | -        model = torch.compile(model, fullgraph=True) | 
| 109 |  | - | 
| 110 |  | -        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): | 
| 111 |  | -            _ = model(**inputs_dict) | 
| 112 |  | -            _ = model(**inputs_dict) | 
| 113 |  | - | 
| 114 |  | - | 
| 115 |  | -class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): | 
|  | 96 | +class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): | 
| 116 | 97 |     model_class = HunyuanVideoTransformer3DModel | 
| 117 | 98 |     main_input_name = "hidden_states" | 
| 118 | 99 |     uses_custom_attn_processor = True | 
| @@ -179,23 +160,8 @@ def test_gradient_checkpointing_is_applied(self): | 
| 179 | 160 |         expected_set = {"HunyuanVideoTransformer3DModel"} | 
| 180 | 161 |         super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | 
| 181 | 162 | 
 | 
| 182 |  | -    @require_torch_gpu | 
| 183 |  | -    @require_torch_2 | 
| 184 |  | -    @is_torch_compile | 
| 185 |  | -    @slow | 
| 186 |  | -    def test_torch_compile_recompilation_and_graph_break(self): | 
| 187 |  | -        torch._dynamo.reset() | 
| 188 |  | -        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() | 
| 189 |  | - | 
| 190 |  | -        model = self.model_class(**init_dict).to(torch_device) | 
| 191 |  | -        model = torch.compile(model, fullgraph=True) | 
| 192 |  | - | 
| 193 |  | -        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): | 
| 194 |  | -            _ = model(**inputs_dict) | 
| 195 |  | -            _ = model(**inputs_dict) | 
| 196 |  | - | 
| 197 | 163 | 
 | 
| 198 |  | -class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): | 
|  | 164 | +class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): | 
| 199 | 165 |     model_class = HunyuanVideoTransformer3DModel | 
| 200 | 166 |     main_input_name = "hidden_states" | 
| 201 | 167 |     uses_custom_attn_processor = True | 
| @@ -260,23 +226,10 @@ def test_gradient_checkpointing_is_applied(self): | 
| 260 | 226 |         expected_set = {"HunyuanVideoTransformer3DModel"} | 
| 261 | 227 |         super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | 
| 262 | 228 | 
 | 
| 263 |  | -    @require_torch_gpu | 
| 264 |  | -    @require_torch_2 | 
| 265 |  | -    @is_torch_compile | 
| 266 |  | -    @slow | 
| 267 |  | -    def test_torch_compile_recompilation_and_graph_break(self): | 
| 268 |  | -        torch._dynamo.reset() | 
| 269 |  | -        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() | 
| 270 | 229 | 
 | 
| 271 |  | -        model = self.model_class(**init_dict).to(torch_device) | 
| 272 |  | -        model = torch.compile(model, fullgraph=True) | 
| 273 |  | - | 
| 274 |  | -        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): | 
| 275 |  | -            _ = model(**inputs_dict) | 
| 276 |  | -            _ = model(**inputs_dict) | 
| 277 |  | - | 
| 278 |  | - | 
| 279 |  | -class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): | 
|  | 230 | +class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests( | 
|  | 231 | +    ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase | 
|  | 232 | +): | 
| 280 | 233 |     model_class = HunyuanVideoTransformer3DModel | 
| 281 | 234 |     main_input_name = "hidden_states" | 
| 282 | 235 |     uses_custom_attn_processor = True | 
| @@ -342,18 +295,3 @@ def test_output(self): | 
| 342 | 295 |     def test_gradient_checkpointing_is_applied(self): | 
| 343 | 296 |         expected_set = {"HunyuanVideoTransformer3DModel"} | 
| 344 | 297 |         super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | 
| 345 |  | - | 
| 346 |  | -    @require_torch_gpu | 
| 347 |  | -    @require_torch_2 | 
| 348 |  | -    @is_torch_compile | 
| 349 |  | -    @slow | 
| 350 |  | -    def test_torch_compile_recompilation_and_graph_break(self): | 
| 351 |  | -        torch._dynamo.reset() | 
| 352 |  | -        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() | 
| 353 |  | - | 
| 354 |  | -        model = self.model_class(**init_dict).to(torch_device) | 
| 355 |  | -        model = torch.compile(model, fullgraph=True) | 
| 356 |  | - | 
| 357 |  | -        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): | 
| 358 |  | -            _ = model(**inputs_dict) | 
| 359 |  | -            _ = model(**inputs_dict) | 
0 commit comments