Commit 6a549d4

add seq_lens to dispatch_attention_fn

1 parent: 18efdde
4 files changed: +68 -23 lines
src/diffusers/models/attention_dispatch.py

Lines changed: 52 additions & 18 deletions
@@ -305,6 +305,7 @@ def dispatch_attention_fn(
     *,
     backend: Optional[AttentionBackendName] = None,
     parallel_config: Optional["ParallelConfig"] = None,
+    seq_lens: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     attention_kwargs = attention_kwargs or {}

@@ -327,6 +328,8 @@ def dispatch_attention_fn(
         **attention_kwargs,
         "_parallel_config": parallel_config,
     }
+    if seq_lens is not None:
+        kwargs["seq_lens"] = seq_lens
     if is_torch_version(">=", "2.5.0"):
         kwargs["enable_gqa"] = enable_gqa

@@ -1400,18 +1403,29 @@ def _flash_varlen_attention(
     is_causal: bool = False,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
+    seq_lens: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     batch_size, seq_len_q, _, _ = query.shape
     _, seq_len_kv, _, _ = key.shape

-    if attn_mask is not None:
-        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+    if seq_lens is not None:
+        seq_lens = seq_lens.to(query.device)
+        # use the same lengths for Q and KV
+        seqlens_k = seq_lens
+        cu_seqlens_q = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32)
+        cu_seqlens_k = cu_seqlens_q
+        max_seqlen_q = int(seq_lens.max().item())
+        max_seqlen_k = max_seqlen_q
+        attn_mask = None  # varlen uses lengths
+    else:
+        if attn_mask is not None:
+            attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

-    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-        _prepare_for_flash_attn_or_sage_varlen(
-            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen(
+                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+            )
         )
-    )

     key_valid, value_valid = [], []
     for b in range(batch_size):
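Worked example (standalone sketch of the new `seq_lens` branch above, not the function itself): the cumulative offsets are built by prepending a zero to the running sum of the per-sample lengths, and the same offsets are reused for Q and KV.

import torch

seq_lens = torch.tensor([5, 3, 7])  # per-sample valid token counts

cu_seqlens = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32)
max_seqlen = int(seq_lens.max().item())

print(cu_seqlens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
print(max_seqlen)  # 7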
@@ -1521,18 +1535,28 @@ def _flash_varlen_attention_3(
     is_causal: bool = False,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
+    seq_lens: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     batch_size, seq_len_q, _, _ = query.shape
     _, seq_len_kv, _, _ = key.shape

-    if attn_mask is not None:
-        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+    if seq_lens is not None:
+        seq_lens = seq_lens.to(query.device)
+        seqlens_k = seq_lens
+        cu_seqlens_q = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32)
+        cu_seqlens_k = cu_seqlens_q
+        max_seqlen_q = int(seq_lens.max().item())
+        max_seqlen_k = max_seqlen_q
+        attn_mask = None  # varlen uses lengths
+    else:
+        if attn_mask is not None:
+            attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

-    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-        _prepare_for_flash_attn_or_sage_varlen(
-            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen(
+                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+            )
         )
-    )

     key_valid, value_valid = [], []
     for b in range(batch_size):
@@ -2023,21 +2047,31 @@ def _sage_varlen_attention(
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
+    seq_lens: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     if return_lse:
         raise ValueError("Sage varlen backend does not support setting `return_lse=True`.")

     batch_size, seq_len_q, _, _ = query.shape
     _, seq_len_kv, _, _ = key.shape

-    if attn_mask is not None:
-        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+    if seq_lens is not None:
+        seq_lens = seq_lens.to(query.device)
+        seqlens_k = seq_lens
+        cu_seqlens_q = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32)
+        cu_seqlens_k = cu_seqlens_q
+        max_seqlen_q = int(seq_lens.max().item())
+        max_seqlen_k = max_seqlen_q
+        attn_mask = None  # varlen uses lengths
+    else:
+        if attn_mask is not None:
+            attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

-    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-        _prepare_for_flash_attn_or_sage_varlen(
-            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen(
+                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+            )
         )
-    )

     key_valid, value_valid = [], []
     for b in range(batch_size):

src/diffusers/models/controlnets/controlnet_qwenimage.py

Lines changed: 5 additions & 1 deletion
@@ -228,6 +228,7 @@ def forward(
             joint_attention_kwargs = joint_attention_kwargs.copy()
             lora_scale = joint_attention_kwargs.pop("scale", 1.0)
         else:
+            joint_attention_kwargs = {}
             lora_scale = 1.0

         if USE_PEFT_BACKEND:

@@ -246,10 +247,13 @@ def forward(
         temb = self.time_text_embed(timestep, hidden_states)

         # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
-        text_seq_len, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
+        text_seq_len, text_seq_lens_per_sample, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
             encoder_hidden_states, encoder_hidden_states_mask
         )

+        if text_seq_lens_per_sample is not None:
+            joint_attention_kwargs.setdefault("text_seq_lens", text_seq_lens_per_sample)
+
         image_rotary_emb = self.pos_embed(img_shapes, text_seq_len, device=hidden_states.device)

         timestep = timestep.to(hidden_states.dtype)
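Note on the `setdefault` call (illustrative snippet, not part of the commit): the mask-derived lengths only act as a default, so a caller who already passes `text_seq_lens` through `joint_attention_kwargs` keeps their own value.

import torch

joint_attention_kwargs = {"text_seq_lens": torch.tensor([4, 4])}  # supplied by the caller
derived_from_mask = torch.tensor([5, 3])

joint_attention_kwargs.setdefault("text_seq_lens", derived_from_mask)
print(joint_attention_kwargs["text_seq_lens"])  # tensor([4, 4]) -- the caller's value wins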

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 9 additions & 3 deletions
@@ -143,7 +143,7 @@ def apply_rotary_emb_qwen(

 def compute_text_seq_len_from_mask(
     encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: Optional[torch.Tensor]
-) -> Tuple[int, Optional[torch.Tensor]]:
+) -> Tuple[int, Optional[torch.Tensor], Optional[torch.Tensor]]:
     """
     Compute text sequence length without assuming contiguous masks. Returns length for RoPE and a normalized bool mask.
     """

@@ -166,7 +166,7 @@ def compute_text_seq_len_from_mask(
     per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
     rope_text_seq_len = max(text_seq_len, int(per_sample_len.max().item()))

-    return rope_text_seq_len, encoder_hidden_states_mask
+    return rope_text_seq_len, per_sample_len, encoder_hidden_states_mask


 class QwenTimestepProjEmbeddings(nn.Module):

@@ -308,6 +308,7 @@ def __call__(
         encoder_hidden_states_mask: torch.FloatTensor = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        text_seq_lens: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         if encoder_hidden_states is None:
             raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")

@@ -394,6 +395,7 @@ def __call__(
             is_causal=False,
             backend=self._attention_backend,
             parallel_config=self._parallel_config,
+            seq_lens=text_seq_lens,
         )

         # Reshape back

@@ -665,6 +667,7 @@ def forward(
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
         else:
+            attention_kwargs = {}
             lora_scale = 1.0

         if USE_PEFT_BACKEND:

@@ -683,10 +686,13 @@ def forward(
         encoder_hidden_states = self.txt_in(encoder_hidden_states)

         # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
-        text_seq_len, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
+        text_seq_len, text_seq_lens_per_sample, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
             encoder_hidden_states, encoder_hidden_states_mask
         )

+        if text_seq_lens_per_sample is not None:
+            attention_kwargs.setdefault("text_seq_lens", text_seq_lens_per_sample)
+
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000
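To make the new second return value concrete, here is a standalone sketch of the per-sample length computation for a non-contiguous mask. Only the final `torch.where` line is visible in the hunk above; the construction of `active_positions` below is an assumed stand-in for illustration.

import torch

mask = torch.tensor([[1, 1, 1, 0, 1, 0, 0, 0],   # last active index 4 -> length 5
                     [1, 1, 0, 0, 0, 0, 0, 0]])  # last active index 1 -> length 2
text_seq_len = mask.shape[1]

bool_mask = mask.bool()
has_active = bool_mask.any(dim=1)
active_positions = torch.arange(text_seq_len) * bool_mask  # inactive slots contribute 0
per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
print(per_sample_len)  # tensor([5, 2])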

tests/models/transformers/test_models_transformer_qwenimage.py

Lines changed: 2 additions & 1 deletion
@@ -134,9 +134,10 @@ def test_non_contiguous_attention_mask(self):
         encoder_hidden_states_mask[:, 3] = 0
         encoder_hidden_states_mask[:, 5:] = 0

-        inferred_rope_len, normalized_mask = compute_text_seq_len_from_mask(
+        inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
             inputs["encoder_hidden_states"], encoder_hidden_states_mask
         )
+        self.assertEqual(int(per_sample_len.max().item()), 5)
         self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
         self.assertTrue(normalized_mask.dtype == torch.bool)
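Why 5: with position 3 and everything from position 5 onward zeroed out, the last active position in every row is index 4, so the expected per-sample length is 4 + 1 = 5, which is exactly what the new assertion checks; the inferred RoPE length still equals the full padded width, as before.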
