Skip to content

Commit 253ef7e

Browse files
committed
add fast test
1 parent 7938b42 commit 253ef7e

File tree

3 files changed

+143
-1
lines changed

3 files changed

+143
-1
lines changed

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import torch
1919

2020
from diffusers import FluxTransformer2DModel
21+
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
22+
from diffusers.models.embeddings import ImageProjection
2123
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
2224

2325
from ..test_modeling_common import ModelTesterMixin
@@ -26,6 +28,56 @@
2628
enable_full_determinism()
2729

2830

31+
def create_flux_ip_adapter_state_dict(model):
32+
# "ip_adapter" (cross-attention weights)
33+
ip_cross_attn_state_dict = {}
34+
key_id = 0
35+
36+
for name in model.attn_processors.keys():
37+
if name.startswith("single_transformer_blocks"):
38+
continue
39+
40+
joint_attention_dim = model.config["joint_attention_dim"]
41+
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
42+
sd = FluxIPAdapterJointAttnProcessor2_0(
43+
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
44+
).state_dict()
45+
ip_cross_attn_state_dict.update(
46+
{
47+
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
48+
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
49+
f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
50+
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
51+
}
52+
)
53+
54+
key_id += 1
55+
56+
# "image_proj" (ImageProjection layer weights)
57+
58+
image_projection = ImageProjection(
59+
cross_attention_dim=model.config["joint_attention_dim"],
60+
image_embed_dim=model.config["pooled_projection_dim"],
61+
num_image_text_embeds=4,
62+
)
63+
64+
ip_image_projection_state_dict = {}
65+
sd = image_projection.state_dict()
66+
ip_image_projection_state_dict.update(
67+
{
68+
"proj.weight": sd["image_embeds.weight"],
69+
"proj.bias": sd["image_embeds.bias"],
70+
"norm.weight": sd["norm.weight"],
71+
"norm.bias": sd["norm.bias"],
72+
}
73+
)
74+
75+
del sd
76+
ip_state_dict = {}
77+
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
78+
return ip_state_dict
79+
80+
2981
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
3082
model_class = FluxTransformer2DModel
3183
main_input_name = "hidden_states"

tests/pipelines/flux/test_pipeline_flux.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616
)
1717

1818
from ..test_pipelines_common import (
19+
FluxIPAdapterTesterMixin,
1920
PipelineTesterMixin,
2021
check_qkv_fusion_matches_attn_procs_length,
2122
check_qkv_fusion_processors_exist,
2223
)
2324

2425

25-
class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
26+
class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):
2627
pipeline_class = FluxPipeline
2728
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
2829
batch_params = frozenset(["prompt"])

tests/pipelines/test_pipelines_common.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
get_autoencoder_tiny_config,
5555
get_consistency_vae_config,
5656
)
57+
from ..models.transformers.test_models_transformer_flux import create_flux_ip_adapter_state_dict
5758
from ..models.unets.test_models_unet_2d_condition import (
5859
create_ip_adapter_faceid_state_dict,
5960
create_ip_adapter_state_dict,
@@ -483,6 +484,94 @@ def test_ip_adapter_faceid(self, expected_max_diff: float = 1e-4):
483484
)
484485

485486

487+
class FluxIPAdapterTesterMixin:
488+
"""
489+
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
490+
It provides a set of common tests for pipelines that support IP Adapters.
491+
"""
492+
493+
def test_pipeline_signature(self):
494+
parameters = inspect.signature(self.pipeline_class.__call__).parameters
495+
496+
assert issubclass(self.pipeline_class, FluxIPAdapterTesterMixin)
497+
self.assertIn(
498+
"ip_adapter_image",
499+
parameters,
500+
"`ip_adapter_image` argument must be supported by the `__call__` method",
501+
)
502+
self.assertIn(
503+
"ip_adapter_image_embeds",
504+
parameters,
505+
"`ip_adapter_image_embeds` argument must be supported by the `__call__` method",
506+
)
507+
508+
def _get_dummy_image_embeds(self, image_embed_dim: int = 768):
509+
return torch.randn((1, 1, image_embed_dim), device=torch_device)
510+
511+
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
512+
inputs["negative_prompt"] = ""
513+
inputs["true_cfg_scale"] = 4.0
514+
inputs["output_type"] = "np"
515+
inputs["return_dict"] = False
516+
return inputs
517+
518+
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
519+
r"""Tests for IP-Adapter.
520+
521+
The following scenarios are tested:
522+
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
523+
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
524+
"""
525+
# Raising the tolerance for this test when it's run on a CPU because we
526+
# compare against static slices and that can be shaky (with a VVVV low probability).
527+
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
528+
529+
components = self.get_dummy_components()
530+
pipe = self.pipeline_class(**components).to(torch_device)
531+
pipe.set_progress_bar_config(disable=None)
532+
image_embed_dim = pipe.transformer.config.pooled_projection_dim
533+
534+
# forward pass without ip adapter
535+
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
536+
if expected_pipe_slice is None:
537+
output_without_adapter = pipe(**inputs)[0]
538+
else:
539+
output_without_adapter = expected_pipe_slice
540+
541+
adapter_state_dict = create_flux_ip_adapter_state_dict(pipe.transformer)
542+
pipe.transformer._load_ip_adapter_weights(adapter_state_dict)
543+
544+
# forward pass with single ip adapter, but scale=0 which should have no effect
545+
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
546+
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
547+
inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
548+
pipe.set_ip_adapter_scale(0.0)
549+
output_without_adapter_scale = pipe(**inputs)[0]
550+
if expected_pipe_slice is not None:
551+
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
552+
553+
# forward pass with single ip adapter, but with scale of adapter weights
554+
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
555+
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
556+
inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
557+
pipe.set_ip_adapter_scale(42.0)
558+
output_with_adapter_scale = pipe(**inputs)[0]
559+
if expected_pipe_slice is not None:
560+
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
561+
562+
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
563+
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
564+
565+
self.assertLess(
566+
max_diff_without_adapter_scale,
567+
expected_max_diff,
568+
"Output without ip-adapter must be same as normal inference",
569+
)
570+
self.assertGreater(
571+
max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference"
572+
)
573+
574+
486575
class PipelineLatentTesterMixin:
487576
"""
488577
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.

0 commit comments

Comments
 (0)