diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 2a61d6613525..33c876535871 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -78,9 +78,7 @@ def create_flux_ip_adapter_state_dict(model): return ip_state_dict -class FluxTransformerTests( - ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase -): +class FluxTransformerTests(ModelTesterMixin, unittest.TestCase): model_class = FluxTransformer2DModel main_input_name = "hidden_states" # We override the items here because the transformer under consideration is small. @@ -169,3 +167,17 @@ def test_deprecated_inputs_img_txt_ids_3d(self): def test_gradient_checkpointing_is_applied(self): expected_set = {"FluxTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = FluxTransformer2DModel + + def prepare_init_args_and_inputs_for_common(self): + return FluxTransformerTests().prepare_init_args_and_inputs_for_common() + + +class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): + model_class = FluxTransformer2DModel + + def prepare_init_args_and_inputs_for_common(self): + return FluxTransformerTests().prepare_init_args_and_inputs_for_common() diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video.py b/tests/models/transformers/test_models_transformer_hunyuan_video.py index 5c83d22ab6aa..112acd0658f5 100644 --- a/tests/models/transformers/test_models_transformer_hunyuan_video.py +++ b/tests/models/transformers/test_models_transformer_hunyuan_video.py @@ -28,7 +28,7 @@ enable_full_determinism() -class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): +class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): model_class = HunyuanVideoTransformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True @@ -93,7 +93,14 @@ def test_gradient_checkpointing_is_applied(self): super().test_gradient_checkpointing_is_applied(expected_set=expected_set) -class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): +class HunyuanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = HunyuanVideoTransformer3DModel + + def prepare_init_args_and_inputs_for_common(self): + return HunyuanVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() + + +class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): model_class = HunyuanVideoTransformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True @@ -161,7 +168,14 @@ def test_gradient_checkpointing_is_applied(self): super().test_gradient_checkpointing_is_applied(expected_set=expected_set) -class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): +class HunyuanSkyreelsImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = HunyuanVideoTransformer3DModel + + def prepare_init_args_and_inputs_for_common(self): + return HunyuanSkyreelsImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() + + +class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): model_class = HunyuanVideoTransformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True @@ -227,9 +241,14 @@ def test_gradient_checkpointing_is_applied(self): super().test_gradient_checkpointing_is_applied(expected_set=expected_set) -class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests( - ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase -): +class HunyuanImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = HunyuanVideoTransformer3DModel + + def prepare_init_args_and_inputs_for_common(self): + return HunyuanVideoImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() + + +class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): model_class = HunyuanVideoTransformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True @@ -295,3 +314,10 @@ def test_output(self): def test_gradient_checkpointing_is_applied(self): expected_set = {"HunyuanVideoTransformer3DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class HunyuanVideoTokenReplaceCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = HunyuanVideoTransformer3DModel + + def prepare_init_args_and_inputs_for_common(self): + return HunyuanVideoTokenReplaceImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() diff --git a/tests/models/transformers/test_models_transformer_ltx.py b/tests/models/transformers/test_models_transformer_ltx.py index 8649ce97a52e..e624c83b44e5 100644 --- a/tests/models/transformers/test_models_transformer_ltx.py +++ b/tests/models/transformers/test_models_transformer_ltx.py @@ -26,7 +26,7 @@ enable_full_determinism() -class LTXTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): +class LTXTransformerTests(ModelTesterMixin, unittest.TestCase): model_class = LTXVideoTransformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True @@ -81,3 +81,10 @@ def prepare_init_args_and_inputs_for_common(self): def test_gradient_checkpointing_is_applied(self): expected_set = {"LTXVideoTransformer3DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = LTXVideoTransformer3DModel + + def prepare_init_args_and_inputs_for_common(self): + return LTXTransformerTests().prepare_init_args_and_inputs_for_common() diff --git a/tests/models/transformers/test_models_transformer_wan.py b/tests/models/transformers/test_models_transformer_wan.py index 4eadb892364a..682289c6c7f9 100644 --- a/tests/models/transformers/test_models_transformer_wan.py +++ b/tests/models/transformers/test_models_transformer_wan.py @@ -28,7 +28,7 @@ enable_full_determinism() -class WanTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): +class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase): model_class = WanTransformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True @@ -82,3 +82,10 @@ def prepare_init_args_and_inputs_for_common(self): def test_gradient_checkpointing_is_applied(self): expected_set = {"WanTransformer3DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class WanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = WanTransformer3DModel + + def prepare_init_args_and_inputs_for_common(self): + return WanTransformer3DTests().prepare_init_args_and_inputs_for_common() diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 24d944bbf979..ab0dcbc1de11 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -355,9 +355,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True): return custom_diffusion_attn_procs -class UNet2DConditionModelTests( - ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase -): +class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = UNet2DConditionModel main_input_name = "sample" # We override the items here because the unet under consideration is small. @@ -1147,6 +1145,20 @@ def test_save_attn_procs_raise_warning(self): assert "Using the `save_attn_procs()` method has been deprecated" in warning_message +class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = UNet2DConditionModel + + def prepare_init_args_and_inputs_for_common(self): + return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common() + + +class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): + model_class = UNet2DConditionModel + + def prepare_init_args_and_inputs_for_common(self): + return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common() + + @slow class UNet2DConditionModelIntegrationTests(unittest.TestCase): def get_file_format(self, seed, shape):