Commit a0b276d

fix more tests
1 parent bc64f12

2 files changed (+6, −12 lines)

src/diffusers/hooks/pyramid_attention_broadcast.py

Lines changed: 2 additions & 1 deletion

@@ -18,6 +18,7 @@
 
 import torch
 
+from ..models.attention import AttentionModuleMixin
 from ..models.attention_processor import Attention, MochiAttention
 from ..utils import logging
 from .hooks import HookRegistry, ModelHook
@@ -227,7 +228,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt
         config.spatial_attention_block_skip_range = 2
 
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _ATTENTION_CLASSES):
+        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
            # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
            # cannot be applied to this layer. For custom layers, users can extend this functionality and implement
            # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
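The widened isinstance check means apply_pyramid_attention_broadcast now hooks any block that inherits from AttentionModuleMixin, not only the concrete classes in _ATTENTION_CLASSES. Below is a minimal sketch of that effect; MyCustomAttention and collect_pab_candidates are hypothetical names used only for illustration and are not part of diffusers.

# Minimal sketch (not the diffusers implementation): shows why broadening the
# isinstance check also picks up custom AttentionModuleMixin subclasses.
import torch
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import Attention, MochiAttention

_ATTENTION_CLASSES = (Attention, MochiAttention)


class MyCustomAttention(torch.nn.Module, AttentionModuleMixin):
    # Hypothetical user-defined attention block: not an Attention/MochiAttention
    # instance, but it inherits AttentionModuleMixin, so the new check sees it.
    def forward(self, hidden_states):
        return hidden_states


def collect_pab_candidates(module: torch.nn.Module) -> list[str]:
    # Old check: isinstance(submodule, _ATTENTION_CLASSES) skipped MyCustomAttention.
    # New check: also accepts anything that inherits AttentionModuleMixin.
    return [
        name
        for name, submodule in module.named_modules()
        if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin))
    ]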

tests/pipelines/chroma/test_pipeline_chroma_img2img.py

Lines changed: 4 additions & 11 deletions

@@ -8,12 +8,7 @@
 from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
 from diffusers.utils.testing_utils import floats_tensor, torch_device
 
-from ..test_pipelines_common import (
-    FluxIPAdapterTesterMixin,
-    PipelineTesterMixin,
-    check_qkv_fusion_matches_attn_procs_length,
-    check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
 
 
 class ChromaImg2ImgPipelineFastTests(
@@ -129,12 +124,10 @@ def test_fused_qkv_projections(self):
         # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
         # to the pipeline level.
         pipe.transformer.fuse_qkv_projections()
-        assert check_qkv_fusion_processors_exist(pipe.transformer), (
-            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        self.assertTrue(
+            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
         )
-        assert check_qkv_fusion_matches_attn_procs_length(
-            pipe.transformer, pipe.transformer.original_attn_processors
-        ), "Something wrong with the attention processors concerning the fused QKV projections."
 
         inputs = self.get_dummy_inputs(device)
         image = pipe(**inputs).images
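check_qkv_fused_layers_exist replaces the two processor-based assertions; its real implementation lives in tests/pipelines/test_pipelines_common.py and is not shown in this diff. The following is a plausible minimal sketch of what such a check does, assuming it only verifies that every attention module exposes the listed fused projection attributes; the function body is a guess, not the library code.

# Illustrative sketch only: the real check_qkv_fused_layers_exist lives in
# tests/pipelines/test_pipelines_common.py and may differ in its details.
import torch
from diffusers.models.attention import AttentionModuleMixin


def check_qkv_fused_layers_exist_sketch(model: torch.nn.Module, layer_names: list[str]) -> bool:
    # After model.fuse_qkv_projections(), each attention block is expected to
    # expose a fused projection (e.g. `to_qkv`) in place of separate to_q/to_k/to_v.
    found_any = False
    for submodule in model.modules():
        if not isinstance(submodule, AttentionModuleMixin):
            continue
        found_any = True
        if not all(hasattr(submodule, name) for name in layer_names):
            return False
    return found_any

In the test above it is invoked as check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), i.e. the only fused layer name it needs to find is to_qkv.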
