Skip to content

Commit c141520

Browse files
committed
fix more tests
1 parent a0b276d commit c141520

File tree

4 files changed

+13
-31
lines changed

4 files changed

+13
-31
lines changed

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def __call__(
242242
key = apply_rotary_emb(key, image_rotary_emb)
243243

244244
hidden_states = torch.nn.functional.scaled_dot_product_attention(
245-
query, key, value, dropout_p=0.0, is_causal=False
245+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
246246
)
247247
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
248248
hidden_states = hidden_states.to(query.dtype)

tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,7 @@
1616
)
1717
from diffusers.utils.torch_utils import randn_tensor
1818

19-
from ..test_pipelines_common import (
20-
PipelineTesterMixin,
21-
check_qkv_fusion_matches_attn_procs_length,
22-
check_qkv_fusion_processors_exist,
23-
)
19+
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
2420

2521

2622
class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -170,12 +166,10 @@ def test_fused_qkv_projections(self):
170166
original_image_slice = image[0, -3:, -3:, -1]
171167

172168
pipe.transformer.fuse_qkv_projections()
173-
assert check_qkv_fusion_processors_exist(pipe.transformer), (
174-
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
169+
self.assertTrue(
170+
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
171+
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
175172
)
176-
assert check_qkv_fusion_matches_attn_procs_length(
177-
pipe.transformer, pipe.transformer.original_attn_processors
178-
), "Something wrong with the attention processors concerning the fused QKV projections."
179173

180174
inputs = self.get_dummy_inputs(device)
181175
image = pipe(**inputs).images

tests/pipelines/flux/test_pipeline_flux_control.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,7 @@
88
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel
99
from diffusers.utils.testing_utils import torch_device
1010

11-
from ..test_pipelines_common import (
12-
PipelineTesterMixin,
13-
check_qkv_fusion_matches_attn_procs_length,
14-
check_qkv_fusion_processors_exist,
15-
)
11+
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
1612

1713

1814
class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -140,12 +136,10 @@ def test_fused_qkv_projections(self):
140136
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
141137
# to the pipeline level.
142138
pipe.transformer.fuse_qkv_projections()
143-
assert check_qkv_fusion_processors_exist(pipe.transformer), (
144-
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
139+
self.assertTrue(
140+
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
141+
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
145142
)
146-
assert check_qkv_fusion_matches_attn_procs_length(
147-
pipe.transformer, pipe.transformer.original_attn_processors
148-
), "Something wrong with the attention processors concerning the fused QKV projections."
149143

150144
inputs = self.get_dummy_inputs(device)
151145
image = pipe(**inputs).images

tests/pipelines/flux/test_pipeline_flux_control_inpaint.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,7 @@
1515
torch_device,
1616
)
1717

18-
from ..test_pipelines_common import (
19-
PipelineTesterMixin,
20-
check_qkv_fusion_matches_attn_procs_length,
21-
check_qkv_fusion_processors_exist,
22-
)
18+
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
2319

2420

2521
class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -134,12 +130,10 @@ def test_fused_qkv_projections(self):
134130
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
135131
# to the pipeline level.
136132
pipe.transformer.fuse_qkv_projections()
137-
assert check_qkv_fusion_processors_exist(pipe.transformer), (
138-
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
133+
self.assertTrue(
134+
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
135+
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
139136
)
140-
assert check_qkv_fusion_matches_attn_procs_length(
141-
pipe.transformer, pipe.transformer.original_attn_processors
142-
), "Something wrong with the attention processors concerning the fused QKV projections."
143137

144138
inputs = self.get_dummy_inputs(device)
145139
image = pipe(**inputs).images

0 commit comments

Comments (0)