Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions tests/models/transformers/test_models_transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ def create_flux_ip_adapter_state_dict(model):
return ip_state_dict


class FluxTransformerTests(
ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase
):
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
Expand Down Expand Up @@ -169,3 +167,17 @@ def test_deprecated_inputs_img_txt_ids_3d(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel

def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()


class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel

def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
38 changes: 32 additions & 6 deletions tests/models/transformers/test_models_transformer_hunyuan_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
enable_full_determinism()


class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
Expand Down Expand Up @@ -93,7 +93,14 @@ def test_gradient_checkpointing_is_applied(self):
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
class HunyuanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel

def prepare_init_args_and_inputs_for_common(self):
return HunyuanVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()


class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
Expand Down Expand Up @@ -161,7 +168,14 @@ def test_gradient_checkpointing_is_applied(self):
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
class HunyuanSkyreelsImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel

def prepare_init_args_and_inputs_for_common(self):
return HunyuanSkyreelsImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()


class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
Expand Down Expand Up @@ -227,9 +241,14 @@ def test_gradient_checkpointing_is_applied(self):
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(
ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase
):
class HunyuanImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel

def prepare_init_args_and_inputs_for_common(self):
return HunyuanVideoImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()


class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
Expand Down Expand Up @@ -295,3 +314,10 @@ def test_output(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class HunyuanVideoTokenReplaceCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel

def prepare_init_args_and_inputs_for_common(self):
return HunyuanVideoTokenReplaceImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
9 changes: 8 additions & 1 deletion tests/models/transformers/test_models_transformer_ltx.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
enable_full_determinism()


class LTXTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = LTXVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
Expand Down Expand Up @@ -81,3 +81,10 @@ def prepare_init_args_and_inputs_for_common(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"LTXVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = LTXVideoTransformer3DModel

def prepare_init_args_and_inputs_for_common(self):
return LTXTransformerTests().prepare_init_args_and_inputs_for_common()
9 changes: 8 additions & 1 deletion tests/models/transformers/test_models_transformer_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
enable_full_determinism()


class WanTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
Expand Down Expand Up @@ -82,3 +82,10 @@ def prepare_init_args_and_inputs_for_common(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class WanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel

def prepare_init_args_and_inputs_for_common(self):
return WanTransformer3DTests().prepare_init_args_and_inputs_for_common()
18 changes: 15 additions & 3 deletions tests/models/unets/test_models_unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
return custom_diffusion_attn_procs


class UNet2DConditionModelTests(
ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
):
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
main_input_name = "sample"
# We override the items here because the unet under consideration is small.
Expand Down Expand Up @@ -1147,6 +1145,20 @@ def test_save_attn_procs_raise_warning(self):
assert "Using the `save_attn_procs()` method has been deprecated" in warning_message


class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel

def prepare_init_args_and_inputs_for_common(self):
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()


class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel

def prepare_init_args_and_inputs_for_common(self):
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()


@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
def get_file_format(self, seed, shape):
Expand Down
Loading