Skip to content

Commit 9b8015c

Browse files
committed
fix how compiler tester mixins are used.
1 parent f46abfe commit 9b8015c

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,7 @@ def create_flux_ip_adapter_state_dict(model):
7878
return ip_state_dict
7979

8080

81-
class FluxTransformerTests(
82-
ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase
83-
):
81+
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
8482
model_class = FluxTransformer2DModel
8583
main_input_name = "hidden_states"
8684
# We override the items here because the transformer under consideration is small.
@@ -169,3 +167,17 @@ def test_deprecated_inputs_img_txt_ids_3d(self):
169167
def test_gradient_checkpointing_is_applied(self):
170168
expected_set = {"FluxTransformer2DModel"}
171169
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
170+
171+
172+
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
173+
model_class = FluxTransformer2DModel
174+
175+
def prepare_init_args_and_inputs_for_common(self):
176+
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
177+
178+
179+
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
180+
model_class = FluxTransformer2DModel
181+
182+
def prepare_init_args_and_inputs_for_common(self):
183+
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,6 +1147,20 @@ def test_save_attn_procs_raise_warning(self):
11471147
assert "Using the `save_attn_procs()` method has been deprecated" in warning_message
11481148

11491149

1150+
class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase):
1151+
model_class = UNet2DConditionModel
1152+
1153+
def prepare_init_args_and_inputs_for_common(self):
1154+
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
1155+
1156+
1157+
class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
1158+
model_class = UNet2DConditionModel
1159+
1160+
def prepare_init_args_and_inputs_for_common(self):
1161+
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
1162+
1163+
11501164
@slow
11511165
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
11521166
def get_file_format(self, seed, shape):

0 commit comments

Comments
 (0)