Skip to content

Commit d4a380d

Browse files
committed
propagate
1 parent 9b8015c commit d4a380d

File tree

3 files changed

+48
-8
lines changed

3 files changed

+48
-8
lines changed

tests/models/transformers/test_models_transformer_hunyuan_video.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
enable_full_determinism()
2929

3030

31-
class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
31+
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
3232
model_class = HunyuanVideoTransformer3DModel
3333
main_input_name = "hidden_states"
3434
uses_custom_attn_processor = True
@@ -93,7 +93,14 @@ def test_gradient_checkpointing_is_applied(self):
9393
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
9494

9595

96-
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
96+
class HunyuanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
97+
model_class = HunyuanVideoTransformer3DModel
98+
99+
def prepare_init_args_and_inputs_for_common(self):
100+
return HunyuanVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
101+
102+
103+
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
97104
model_class = HunyuanVideoTransformer3DModel
98105
main_input_name = "hidden_states"
99106
uses_custom_attn_processor = True
@@ -161,7 +168,14 @@ def test_gradient_checkpointing_is_applied(self):
161168
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
162169

163170

164-
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
171+
class HunyuanSkyreelsImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
172+
model_class = HunyuanVideoTransformer3DModel
173+
174+
def prepare_init_args_and_inputs_for_common(self):
175+
return HunyuanSkyreelsImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
176+
177+
178+
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
165179
model_class = HunyuanVideoTransformer3DModel
166180
main_input_name = "hidden_states"
167181
uses_custom_attn_processor = True
@@ -227,9 +241,14 @@ def test_gradient_checkpointing_is_applied(self):
227241
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
228242

229243

230-
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(
231-
ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase
232-
):
244+
class HunyuanImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
245+
model_class = HunyuanVideoTransformer3DModel
246+
247+
def prepare_init_args_and_inputs_for_common(self):
248+
return HunyuanVideoImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
249+
250+
251+
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
233252
model_class = HunyuanVideoTransformer3DModel
234253
main_input_name = "hidden_states"
235254
uses_custom_attn_processor = True
@@ -295,3 +314,10 @@ def test_output(self):
295314
def test_gradient_checkpointing_is_applied(self):
296315
expected_set = {"HunyuanVideoTransformer3DModel"}
297316
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
317+
318+
319+
class HunyuanVideoTokenReplaceCompileTests(TorchCompileTesterMixin, unittest.TestCase):
320+
model_class = HunyuanVideoTransformer3DModel
321+
322+
def prepare_init_args_and_inputs_for_common(self):
323+
return HunyuanVideoTokenReplaceImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()

tests/models/transformers/test_models_transformer_ltx.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
enable_full_determinism()
2727

2828

29-
class LTXTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
29+
class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
3030
model_class = LTXVideoTransformer3DModel
3131
main_input_name = "hidden_states"
3232
uses_custom_attn_processor = True
@@ -81,3 +81,10 @@ def prepare_init_args_and_inputs_for_common(self):
8181
def test_gradient_checkpointing_is_applied(self):
8282
expected_set = {"LTXVideoTransformer3DModel"}
8383
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
84+
85+
86+
class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
87+
model_class = LTXVideoTransformer3DModel
88+
89+
def prepare_init_args_and_inputs_for_common(self):
90+
return LTXTransformerTests().prepare_init_args_and_inputs_for_common()

tests/models/transformers/test_models_transformer_wan.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
enable_full_determinism()
2929

3030

31-
class WanTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
31+
class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
3232
model_class = WanTransformer3DModel
3333
main_input_name = "hidden_states"
3434
uses_custom_attn_processor = True
@@ -82,3 +82,10 @@ def prepare_init_args_and_inputs_for_common(self):
8282
def test_gradient_checkpointing_is_applied(self):
8383
expected_set = {"WanTransformer3DModel"}
8484
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
85+
86+
87+
class WanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
88+
model_class = WanTransformer3DModel
89+
90+
def prepare_init_args_and_inputs_for_common(self):
91+
return WanTransformer3DTests().prepare_init_args_and_inputs_for_common()

0 commit comments

Comments
 (0)