@@ -8,12 +8,7 @@
 from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
 from diffusers.utils.testing_utils import floats_tensor, torch_device
 
-from ..test_pipelines_common import (
-    FluxIPAdapterTesterMixin,
-    PipelineTesterMixin,
-    check_qkv_fusion_matches_attn_procs_length,
-    check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
 
 
 class ChromaImg2ImgPipelineFastTests(
@@ -129,12 +124,10 @@ def test_fused_qkv_projections(self): |
         # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
         # to the pipeline level.
         pipe.transformer.fuse_qkv_projections()
-        assert check_qkv_fusion_processors_exist(pipe.transformer), (
-            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        self.assertTrue(
+            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
         )
-        assert check_qkv_fusion_matches_attn_procs_length(
-            pipe.transformer, pipe.transformer.original_attn_processors
-        ), "Something wrong with the attention processors concerning the fused QKV projections."
 
         inputs = self.get_dummy_inputs(device)
         image = pipe(**inputs).images
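
For context, a minimal sketch of the check the new assertion is assumed to perform. The `hasattr(submodule, "fuse_projections")` filter and the `nn.Linear` type check are illustrative assumptions, not the actual `test_pipelines_common` implementation:

```python
import torch.nn as nn


def check_qkv_fused_layers_exist(model, layer_names):
    # Assumed behavior (illustrative sketch, not the real helper):
    # after fuse_qkv_projections(), every attention module that
    # supports fusion should expose each fused projection
    # (e.g. "to_qkv") as a Linear layer.
    for submodule in model.modules():
        # Only inspect modules that participate in QKV fusion.
        if not hasattr(submodule, "fuse_projections"):
            continue
        for name in layer_names:
            if not isinstance(getattr(submodule, name, None), nn.Linear):
                return False
    return True
```

The test then calls it as `check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"])` and asserts the result, replacing the two processor-based `assert` statements from the old helper pair.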