
Commit 2d424e0

use joint_seq_lens

1 parent 6a549d4

File tree

1 file changed (+8, -1)


src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 8 additions & 1 deletion
@@ -361,6 +361,7 @@ def __call__(
         # If an encoder_hidden_states_mask is provided, turn it into a broadcastable attention mask.
         # The encoder_hidden_states_mask is expected to have 1.0 for valid tokens and 0.0 for padding.
         # We convert it to a boolean mask where True means "attend" and False means "mask out" (don't attend).
+        joint_seq_lens = None
         if encoder_hidden_states_mask is not None and attention_mask is None:
             batch_size, image_seq_len = hidden_states.shape[:2]
             text_seq_len = encoder_hidden_states.shape[1]
@@ -385,6 +386,12 @@ def __call__(
             joint_attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)
             attention_mask = joint_attention_mask[:, None, None, :]
 
+            # For varlen flash attention, we need the JOINT sequence lengths (text + image), not just text.
+            if text_seq_lens is not None:
+                # text_seq_lens contains per-sample text lengths.
+                # Add the image sequence length to get the total joint sequence length.
+                joint_seq_lens = text_seq_lens + image_seq_len
+
         # Compute joint attention
         joint_hidden_states = dispatch_attention_fn(
             joint_query,
@@ -395,7 +402,7 @@ def __call__(
             is_causal=False,
             backend=self._attention_backend,
             parallel_config=self._parallel_config,
-            seq_lens=text_seq_lens,
+            seq_lens=joint_seq_lens,
        )
 
         # Reshape back
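
For context, here is a minimal standalone sketch of why varlen attention needs the joint lengths rather than the text-only ones. The example batch values and the cu_seqlens conversion (the cumulative-offset format varlen flash-attention kernels typically consume) are illustrative assumptions, not code from this file; only text_seq_lens, image_seq_len, and joint_seq_lens correspond to names in the diff above.

    import torch

    # Hypothetical per-sample valid text lengths, e.g. derived from an
    # encoder_hidden_states_mask; the image token count is shared batch-wide.
    text_seq_lens = torch.tensor([7, 12, 4])
    image_seq_len = 1024

    # The joint sequence is [text tokens | image tokens], so each sample's
    # valid length is its text length plus the image length.
    joint_seq_lens = text_seq_lens + image_seq_len  # tensor([1031, 1036, 1028])

    # Varlen kernels typically take cumulative lengths (cu_seqlens) marking
    # where each packed sample starts and ends.
    cu_seqlens = torch.zeros(joint_seq_lens.numel() + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(joint_seq_lens, dim=0)
    print(cu_seqlens)  # tensor([   0, 1031, 2067, 3095], dtype=torch.int32)

    # Passing text_seq_lens instead would truncate every sample to its text
    # tokens and drop the image tokens from attention, which is what this
    # commit fixes by passing joint_seq_lens to dispatch_attention_fn.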
