|  | 
| 17 | 17 | import torch | 
| 18 | 18 | 
 | 
| 19 | 19 | from diffusers import HunyuanVideoTransformer3DModel | 
| 20 |  | -from diffusers.utils.testing_utils import enable_full_determinism, torch_device | 
| 21 |  | -from diffusers.utils.testing_utils import require_torch_gpu, require_torch_2, is_torch_compile, slow | 
|  | 20 | +from diffusers.utils.testing_utils import ( | 
|  | 21 | +    enable_full_determinism, | 
|  | 22 | +    is_torch_compile, | 
|  | 23 | +    require_torch_2, | 
|  | 24 | +    require_torch_gpu, | 
|  | 25 | +    slow, | 
|  | 26 | +    torch_device, | 
|  | 27 | +) | 
| 22 | 28 | 
 | 
| 23 | 29 | from ..test_modeling_common import ModelTesterMixin | 
| 24 | 30 | 
 | 
| @@ -86,18 +92,14 @@ def prepare_init_args_and_inputs_for_common(self): | 
| 86 | 92 |         inputs_dict = self.dummy_input | 
| 87 | 93 |         return init_dict, inputs_dict | 
| 88 | 94 | 
 | 
| 89 |  | -    def test_gradient_checkpointing_is_applied(self): | 
| 90 |  | -        expected_set = {"HunyuanVideoTransformer3DModel"} | 
| 91 |  | -        super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | 
| 92 |  | -     | 
| 93 | 95 |     def test_gradient_checkpointing_is_applied(self): | 
| 94 | 96 |         expected_set = {"HunyuanVideoTransformer3DModel"} | 
| 95 | 97 |         super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | 
| 96 | 98 | 
 | 
| 97 | 99 |     @require_torch_gpu | 
| 98 | 100 |     @require_torch_2 | 
| 99 | 101 |     @is_torch_compile | 
| 100 |  | -    @slow     | 
|  | 102 | +    @slow | 
| 101 | 103 |     def test_torch_compile_recompilation_and_graph_break(self): | 
| 102 | 104 |         torch._dynamo.reset() | 
| 103 | 105 |         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() | 
| @@ -191,7 +193,7 @@ def test_torch_compile_recompilation_and_graph_break(self): | 
| 191 | 193 |         with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): | 
| 192 | 194 |             _ = model(**inputs_dict) | 
| 193 | 195 |             _ = model(**inputs_dict) | 
| 194 |  | -             | 
|  | 196 | + | 
| 195 | 197 | 
 | 
| 196 | 198 | class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): | 
| 197 | 199 |     model_class = HunyuanVideoTransformer3DModel | 
| @@ -257,7 +259,7 @@ def test_output(self): | 
| 257 | 259 |     def test_gradient_checkpointing_is_applied(self): | 
| 258 | 260 |         expected_set = {"HunyuanVideoTransformer3DModel"} | 
| 259 | 261 |         super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | 
| 260 |  | -         | 
|  | 262 | + | 
| 261 | 263 |     @require_torch_gpu | 
| 262 | 264 |     @require_torch_2 | 
| 263 | 265 |     @is_torch_compile | 
| @@ -336,7 +338,7 @@ def prepare_init_args_and_inputs_for_common(self): | 
| 336 | 338 | 
 | 
| 337 | 339 |     def test_output(self): | 
| 338 | 340 |         super().test_output(expected_output_shape=(1, *self.output_shape)) | 
| 339 |  | -     | 
|  | 341 | + | 
| 340 | 342 |     def test_gradient_checkpointing_is_applied(self): | 
| 341 | 343 |         expected_set = {"HunyuanVideoTransformer3DModel"} | 
| 342 | 344 |         super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | 
|  | 
0 commit comments