|  | 
| 28 | 28 | enable_full_determinism() | 
| 29 | 29 | 
 | 
| 30 | 30 | 
 | 
| 31 |  | -class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): | 
|  | 31 | +class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): | 
| 32 | 32 |     model_class = HunyuanVideoTransformer3DModel | 
| 33 | 33 |     main_input_name = "hidden_states" | 
| 34 | 34 |     uses_custom_attn_processor = True | 
| @@ -93,7 +93,14 @@ def test_gradient_checkpointing_is_applied(self): | 
| 93 | 93 |         super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | 
| 94 | 94 | 
 | 
| 95 | 95 | 
 | 
| 96 |  | -class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): | 
|  | 96 | +class HunyuanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): | 
|  | 97 | +    model_class = HunyuanVideoTransformer3DModel | 
|  | 98 | + | 
|  | 99 | +    def prepare_init_args_and_inputs_for_common(self): | 
|  | 100 | +        return HunyuanVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() | 
|  | 101 | + | 
|  | 102 | + | 
|  | 103 | +class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): | 
| 97 | 104 |     model_class = HunyuanVideoTransformer3DModel | 
| 98 | 105 |     main_input_name = "hidden_states" | 
| 99 | 106 |     uses_custom_attn_processor = True | 
| @@ -161,7 +168,14 @@ def test_gradient_checkpointing_is_applied(self): | 
| 161 | 168 |         super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | 
| 162 | 169 | 
 | 
| 163 | 170 | 
 | 
| 164 |  | -class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): | 
|  | 171 | +class HunyuanSkyreelsImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase): | 
|  | 172 | +    model_class = HunyuanVideoTransformer3DModel | 
|  | 173 | + | 
|  | 174 | +    def prepare_init_args_and_inputs_for_common(self): | 
|  | 175 | +        return HunyuanSkyreelsImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() | 
|  | 176 | + | 
|  | 177 | + | 
|  | 178 | +class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): | 
| 165 | 179 |     model_class = HunyuanVideoTransformer3DModel | 
| 166 | 180 |     main_input_name = "hidden_states" | 
| 167 | 181 |     uses_custom_attn_processor = True | 
| @@ -227,9 +241,14 @@ def test_gradient_checkpointing_is_applied(self): | 
| 227 | 241 |         super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | 
| 228 | 242 | 
 | 
| 229 | 243 | 
 | 
| 230 |  | -class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests( | 
| 231 |  | -    ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase | 
| 232 |  | -): | 
|  | 244 | +class HunyuanImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase): | 
|  | 245 | +    model_class = HunyuanVideoTransformer3DModel | 
|  | 246 | + | 
|  | 247 | +    def prepare_init_args_and_inputs_for_common(self): | 
|  | 248 | +        return HunyuanVideoImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() | 
|  | 249 | + | 
|  | 250 | + | 
|  | 251 | +class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): | 
| 233 | 252 |     model_class = HunyuanVideoTransformer3DModel | 
| 234 | 253 |     main_input_name = "hidden_states" | 
| 235 | 254 |     uses_custom_attn_processor = True | 
| @@ -295,3 +314,10 @@ def test_output(self): | 
| 295 | 314 |     def test_gradient_checkpointing_is_applied(self): | 
| 296 | 315 |         expected_set = {"HunyuanVideoTransformer3DModel"} | 
| 297 | 316 |         super().test_gradient_checkpointing_is_applied(expected_set=expected_set) | 
|  | 317 | + | 
|  | 318 | + | 
|  | 319 | +class HunyuanVideoTokenReplaceCompileTests(TorchCompileTesterMixin, unittest.TestCase): | 
|  | 320 | +    model_class = HunyuanVideoTransformer3DModel | 
|  | 321 | + | 
|  | 322 | +    def prepare_init_args_and_inputs_for_common(self): | 
|  | 323 | +        return HunyuanVideoTokenReplaceImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() | 
0 commit comments