
Commit 50c4815: fix compile tests
Parent: 35efa06

File tree: 1 file changed, +8 -9 lines


src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 8 additions & 9 deletions
@@ -560,16 +560,15 @@ def __call__(
                     f"must match encoder_hidden_states sequence length ({text_seq_len})."
                 )

-            # Only create mask if there's actual padding (i.e., some False/0 values)
-            # When all values are True/1.0, passing attention_mask=None is more efficient for SDPA
+            # Create joint attention mask
+            # torch.compile compatible: always create mask when encoder_hidden_states_mask is provided
             text_attention_mask = encoder_hidden_states_mask.bool()
-            if not text_attention_mask.all():
-                image_attention_mask = torch.ones(
-                    (batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device
-                )
-                # Create 2D joint mask [batch_size, text_seq_len + image_seq_len]
-                # The attention dispatch will normalize this and extract sequence lengths
-                attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)
+            image_attention_mask = torch.ones(
+                (batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device
+            )
+            # Create 2D joint mask [batch_size, text_seq_len + image_seq_len]
+            # The attention dispatch will normalize this and extract sequence lengths
+            attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)

             # Compute joint attention
             joint_hidden_states = dispatch_attention_fn(
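
Why this fixes the compile tests: the removed branch `if not text_attention_mask.all():` makes control flow depend on tensor values, which torch.compile cannot trace into a single graph (it forces a graph break, or an error under fullgraph=True). Building the joint mask unconditionally keeps the function fully traceable. Below is a minimal standalone sketch of the fixed pattern, not the diffusers code itself; the helper name build_joint_mask and the toy shapes are illustrative assumptions.

# Minimal sketch of the torch.compile-friendly mask construction (assumed helper, not diffusers API).
import torch

def build_joint_mask(encoder_hidden_states_mask, batch_size, image_seq_len, device):
    # No data-dependent `if`: always build the [text + image] joint mask.
    text_attention_mask = encoder_hidden_states_mask.bool()
    image_attention_mask = torch.ones(
        (batch_size, image_seq_len), dtype=torch.bool, device=device
    )
    # Shape: [batch_size, text_seq_len + image_seq_len]
    return torch.cat([text_attention_mask, image_attention_mask], dim=1)

# Compiles cleanly with fullgraph=True because there is no tensor-dependent branching.
compiled = torch.compile(build_joint_mask, fullgraph=True)
mask = compiled(torch.ones(2, 7, dtype=torch.bool), 2, 16, "cpu")
print(mask.shape)  # torch.Size([2, 23])

The trade-off noted in the old comment (passing attention_mask=None when nothing is padded is cheaper for SDPA) is given up here in favor of a single compile-stable code path.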
