Skip to content

Commit ff21b7f

Browse files
committed
improve test
1 parent ecabd2a commit ff21b7f

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

tests/pipelines/flux/test_pipeline_flux.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@
2828
FluxIPAdapterTesterMixin,
2929
PipelineTesterMixin,
3030
PyramidAttentionBroadcastTesterMixin,
31-
check_qkv_fusion_matches_attn_procs_length,
32-
check_qkv_fusion_processors_exist,
31+
check_qkv_fused_layers_exist,
3332
)
3433

3534

@@ -171,12 +170,10 @@ def test_fused_qkv_projections(self):
171170
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
172171
# to the pipeline level.
173172
pipe.transformer.fuse_qkv_projections()
174-
assert check_qkv_fusion_processors_exist(pipe.transformer), (
175-
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
173+
self.assertTrue(
174+
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
175+
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
176176
)
177-
assert check_qkv_fusion_matches_attn_procs_length(
178-
pipe.transformer, pipe.transformer.original_attn_processors
179-
), "Something wrong with the attention processors concerning the fused QKV projections."
180177

181178
inputs = self.get_dummy_inputs(device)
182179
image = pipe(**inputs).images

tests/pipelines/test_pipelines_common.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
3838
from diffusers.image_processor import VaeImageProcessor
3939
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
40+
from diffusers.models.attention import AttentionModuleMixin
4041
from diffusers.models.attention_processor import AttnProcessor
4142
from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel
4243
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
@@ -98,6 +99,20 @@ def check_qkv_fusion_processors_exist(model):
9899
return all(p.startswith("Fused") for p in proc_names)
99100

100101

102+
def check_qkv_fused_layers_exist(model, layer_names):
103+
is_fused_submodules = []
104+
for submodule in model.modules():
105+
if not isinstance(submodule, AttentionModuleMixin):
106+
continue
107+
is_fused_attribute_set = submodule.fused_projections
108+
is_fused_layer = True
109+
for layer in layer_names:
110+
is_fused_layer = is_fused_layer and getattr(submodule, layer, None) is not None
111+
is_fused = is_fused_attribute_set and is_fused_layer
112+
is_fused_submodules.append(is_fused)
113+
return all(is_fused_submodules)
114+
115+
101116
class SDFunctionTesterMixin:
102117
"""
103118
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.

0 commit comments

Comments
 (0)