File tree — 2 files changed: +19 −7 lines changed
lines changed Original file line number Diff line number Diff line change 2828 FluxIPAdapterTesterMixin ,
2929 PipelineTesterMixin ,
3030 PyramidAttentionBroadcastTesterMixin ,
31- check_qkv_fusion_matches_attn_procs_length ,
32- check_qkv_fusion_processors_exist ,
31+ check_qkv_fused_layers_exist ,
3332)
3433
3534
@@ -171,12 +170,10 @@ def test_fused_qkv_projections(self):
171170 # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
172171 # to the pipeline level.
173172 pipe .transformer .fuse_qkv_projections ()
174- assert check_qkv_fusion_processors_exist (pipe .transformer ), (
175- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
173+ self .assertTrue (
174+ check_qkv_fused_layers_exist (pipe .transformer , ["to_qkv" ]),
175+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused." ),
176176 )
177- assert check_qkv_fusion_matches_attn_procs_length (
178- pipe .transformer , pipe .transformer .original_attn_processors
179- ), "Something wrong with the attention processors concerning the fused QKV projections."
180177
181178 inputs = self .get_dummy_inputs (device )
182179 image = pipe (** inputs ).images
Original file line number Diff line number Diff line change 3737from diffusers .hooks .pyramid_attention_broadcast import PyramidAttentionBroadcastHook
3838from diffusers .image_processor import VaeImageProcessor
3939from diffusers .loaders import FluxIPAdapterMixin , IPAdapterMixin
40+ from diffusers .models .attention import AttentionModuleMixin
4041from diffusers .models .attention_processor import AttnProcessor
4142from diffusers .models .controlnets .controlnet_xs import UNetControlNetXSModel
4243from diffusers .models .unets .unet_3d_condition import UNet3DConditionModel
@@ -98,6 +99,20 @@ def check_qkv_fusion_processors_exist(model):
9899 return all (p .startswith ("Fused" ) for p in proc_names )
99100
100101
def check_qkv_fused_layers_exist(model, layer_names):
    """Return ``True`` when every attention module in *model* is fully fused.

    A submodule counts only if it is an ``AttentionModuleMixin``. Such a
    module is considered fused when its ``fused_projections`` flag is truthy
    AND every attribute named in *layer_names* (e.g. ``"to_qkv"``) exists on
    it and is not ``None``. With no attention modules present, returns
    ``True`` vacuously.
    """
    for module in model.modules():
        if not isinstance(module, AttentionModuleMixin):
            continue
        # Flag set by fuse_qkv_projections(), plus the fused layers themselves.
        flag_set = module.fused_projections
        layers_present = all(
            getattr(module, name, None) is not None for name in layer_names
        )
        if not (flag_set and layers_present):
            return False
    return True
114+
115+
101116class SDFunctionTesterMixin :
102117 """
103118 This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
You can’t perform that action at this time.
0 commit comments