
Commit 4b201df (1 parent: 03a7630)

fix flash/sage seqlen preparation when kv len does not match q len (cross attention)

1 file changed: 12 additions, 4 deletions

src/diffusers/models/attention_dispatch.py

@@ -276,11 +276,15 @@ def _check_shape(
 
 
 def _prepare_for_flash_attn_or_sage_varlen(
-    batch_size: int, seq_len_q: int, attn_mask: Optional[torch.Tensor] = None, device: Optional[torch.device] = None
+    batch_size: int,
+    seq_len_q: int,
+    seq_len_kv: int,
+    attn_mask: Optional[torch.Tensor] = None,
+    device: Optional[torch.device] = None,
 ) -> None:
     seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
     if attn_mask is None:
-        seqlens_k = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
+        seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
     else:
         seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
     cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
@@ -440,7 +444,9 @@ def _flash_varlen_attention(
 
     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
         (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-            _prepare_for_flash_attn_or_sage_varlen(batch_size, seq_len_q, attn_mask=attn_mask, device=query.device)
+            _prepare_for_flash_attn_or_sage_varlen(
+                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+            )
         )
     else:
         seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
@@ -730,7 +736,9 @@ def _sage_varlen_attention(
 
     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
         (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-            _prepare_for_flash_attn_or_sage_varlen(batch_size, seq_len_q, attn_mask=attn_mask, device=query.device)
+            _prepare_for_flash_attn_or_sage_varlen(
+                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+            )
         )
     else:
         seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
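
To make the fix concrete, here is a minimal, self-contained sketch of the varlen seqlen preparation, reconstructed from the hunks above. The function body mirrors the lines visible in the diff; the cumulative-sum tail, the max-seqlen computation, the return statement, and the standalone name prepare_varlen_seqlens are assumptions filled in for illustration (the real helper in attention_dispatch.py is only partially visible here, and is annotated -> None even though its call sites unpack three tuples).

# Illustrative sketch only; names and the parts past the diffed lines are
# assumptions, not the diffusers implementation verbatim.
from typing import Optional

import torch


def prepare_varlen_seqlens(
    batch_size: int,
    seq_len_q: int,
    seq_len_kv: int,
    attn_mask: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
):
    # Per-sample query lengths: every sample contributes seq_len_q queries.
    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
    if attn_mask is None:
        # The fixed line: key/value lengths come from seq_len_kv, not seq_len_q.
        # Before the fix, cross attention (kv len != q len) produced wrong offsets.
        seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
    else:
        # With a padding mask, the true kv length per sample is the mask sum.
        seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
    # Flash/Sage varlen kernels index packed tensors via cumulative offsets
    # (assumed tail, based on the context line in the first hunk).
    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
    max_seqlen_q = seqlens_q.max().item()
    max_seqlen_k = seqlens_k.max().item()
    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)


# Cross-attention shape the old code got wrong: 4096 queries over 77 kv tokens.
(_, seqlens_k), (cu_q, cu_k), (max_q, max_k) = prepare_varlen_seqlens(
    batch_size=2, seq_len_q=4096, seq_len_kv=77
)
assert cu_k.tolist() == [0, 77, 154]  # pre-fix this would be [0, 4096, 8192]

With the old code, cu_seqlens_k for a call like this would have been built from seq_len_q, pointing the kernel past the end of the packed key/value buffer whenever kv len differs from q len.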
