
Commit 0f1a4e0

update
1 parent aa29af8 commit 0f1a4e0

5 files changed, +22 −16 lines changed


tests/models/testing_utils/common.py

Lines changed: 5 additions & 9 deletions
@@ -260,7 +260,7 @@ def test_output(self, expected_output_shape=None):
 
         assert output is not None, "Model output is None"
         assert (
-            output.shape == expected_output_shape
+            output[0].shape == expected_output_shape or self.output_shape
         ), f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
 
     def test_outputs_equivalence(self):
@@ -302,15 +302,11 @@ def recursive_check(tuple_object, dict_object):
 
         recursive_check(outputs_tuple, outputs_dict)
 
-    def test_model_config_to_json_string(self):
-        model = self.model_class(**self.get_init_dict())
-
-        json_string = model.config.to_json_string()
-        assert isinstance(json_string, str), "Config to_json_string should return a string"
-        assert len(json_string) > 0, "JSON string should not be empty"
-
     @require_accelerator
-    @pytest.mark.skipif(torch_device not in ["cuda", "xpu"])
+    @pytest.mark.skipif(
+        torch_device not in ["cuda", "xpu"],
+        reason="float16 and bfloat16 can only be use for inference with an accelerator",
+    )
     def test_from_save_pretrained_float16_bfloat16(self):
         model = self.model_class(**self.get_init_dict())
         model.to(torch_device)
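Note on the second hunk: pytest rejects a bare boolean condition passed to `pytest.mark.skipif` unless `reason=` is supplied, erroring at collection time, which is why the one-line decorator was expanded. A minimal sketch of the required form (the condition and test name here are hypothetical stand-ins, not this suite's code):

    import pytest

    HAS_ACCELERATOR = False  # hypothetical stand-in for the torch_device check

    # With a boolean condition, pytest requires an explicit reason string;
    # the expanded decorator in the diff above supplies one.
    @pytest.mark.skipif(
        not HAS_ACCELERATOR,
        reason="float16/bfloat16 inference needs an accelerator",
    )
    def test_half_precision_roundtrip():
        ...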

tests/models/testing_utils/ip_adapter.py

Lines changed: 7 additions & 3 deletions
@@ -100,6 +100,7 @@ def test_load_ip_adapter(self):
         init_dict = self.get_init_dict()
         inputs_dict = self.get_dummy_inputs()
         model = self.model_class(**init_dict).to(torch_device)
+        self.prepare_model(model)
 
         torch.manual_seed(0)
         output_no_adapter = model(**inputs_dict, return_dict=False)[0]
@@ -128,9 +129,10 @@ def test_ip_adapter_scale(self):
         init_dict = self.get_init_dict()
         inputs_dict = self.get_dummy_inputs()
         model = self.model_class(**init_dict).to(torch_device)
+        # self.prepare_model(model)
 
         # Create and load dummy IP adapter state dict
-        ip_adapter_state_dict = create_ip_adapter_state_dict(model)
+        ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
         model._load_ip_adapter_weights([ip_adapter_state_dict])
 
         # Test scale = 0.0 (no effect)
@@ -151,12 +153,13 @@ def test_ip_adapter_scale(self):
     def test_unload_ip_adapter(self):
         init_dict = self.get_init_dict()
         model = self.model_class(**init_dict).to(torch_device)
+        self.prepare_model(model)
 
         # Save original processors
         original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
 
         # Create and load IP adapter
-        ip_adapter_state_dict = create_ip_adapter_state_dict(model)
+        ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
         model._load_ip_adapter_weights([ip_adapter_state_dict])
         assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be set"
 
@@ -172,9 +175,10 @@ def test_ip_adapter_save_load(self):
         init_dict = self.get_init_dict()
         inputs_dict = self.get_dummy_inputs()
         model = self.model_class(**init_dict).to(torch_device)
+        self.prepare_model(model)
 
         # Create and load IP adapter
-        ip_adapter_state_dict = self.create_ip_adapter_state_dict()
+        ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
         model._load_ip_adapter_weights([ip_adapter_state_dict])
 
         torch.manual_seed(0)
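These hunks lean on two hooks the mixin now expects each model-specific tester to provide: `prepare_model`, which installs the model's IP-adapter attention processors before any adapter weights are loaded, and `create_ip_adapter_state_dict`, which builds a dummy adapter state dict for that model (now a method that receives the model). A minimal sketch of a conforming tester, with hypothetical stand-in base classes since the real ones live elsewhere in the suite:

    # Hypothetical stand-ins; in the real suite these come from the
    # testing_utils mixins and the model's tester config class.
    class IPAdapterTesterMixin: ...
    class MyTransformerTesterConfig: ...

    class MyTransformerIPAdapterTests(MyTransformerTesterConfig, IPAdapterTesterMixin):
        def prepare_model(self, model):
            # Swap in the model's IP-adapter attention processors so that
            # _load_ip_adapter_weights has processors to load into.
            ...

        def create_ip_adapter_state_dict(self, model):
            # Return a dummy adapter state dict sized to `model`'s layers.
            return {"ip_adapter": {}}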

tests/models/testing_utils/quantization.py

Lines changed: 1 addition & 1 deletion
@@ -40,9 +40,9 @@
     require_accelerator,
     require_bitsandbytes_version_greater,
     require_gguf_version_greater_or_equal,
+    require_modelopt_version_greater_or_equal,
     require_quanto,
     require_torchao_version_greater_or_equal,
-    require_modelopt_version_greater_or_equal,
     torch_device,
 )

tests/models/transformers/test_models_transformer_flux_.py

Lines changed: 8 additions & 2 deletions
@@ -19,6 +19,7 @@
 
 from diffusers import FluxTransformer2DModel
 from diffusers.models.embeddings import ImageProjection
+from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
 from diffusers.utils.torch_utils import randn_tensor
 
 from ...testing_utils import enable_full_determinism, torch_device
@@ -87,11 +88,11 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
 
     @property
     def input_shape(self) -> tuple[int, int]:
-        return (1, 16, 4)
+        return (16, 4)
 
     @property
    def output_shape(self) -> tuple[int, int]:
-        return (1, 16, 4)
+        return (16, 4)
 
 
 class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
@@ -148,6 +149,11 @@ class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
 class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
     """IP Adapter tests for Flux Transformer."""
 
+    def prepare_model(self, model):
+        joint_attention_dim = model.config["joint_attention_dim"]
+        hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
+        model.set_attn_processor(FluxIPAdapterAttnProcessor(hidden_size, joint_attention_dim, scale=1.0))
+
     def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
         from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
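The new `prepare_model` derives the processor dimensions from the model config before swapping in the IP-adapter processor. A rough sketch of that arithmetic with hypothetical dummy-config values (not necessarily this file's actual ones):

    # Hypothetical dummy config; real values come from get_init_dict().
    config = {"joint_attention_dim": 32, "num_attention_heads": 2, "attention_head_dim": 8}

    hidden_size = config["num_attention_heads"] * config["attention_head_dim"]  # 2 * 8 = 16
    assert hidden_size == 16
    # prepare_model then installs, exactly as in the diff above:
    #   FluxIPAdapterAttnProcessor(hidden_size, config["joint_attention_dim"], scale=1.0)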

tests/testing_utils.py

Lines changed: 1 addition & 1 deletion
@@ -37,8 +37,8 @@
     is_flax_available,
     is_gguf_available,
     is_kernels_available,
-    is_nvidia_modelopt_available,
     is_note_seq_available,
+    is_nvidia_modelopt_available,
     is_onnx_available,
     is_opencv_available,
     is_optimum_quanto_available,
