Commit 66ce9cc

refactor

1 parent bb443f9 commit 66ce9cc

1 file changed

src/diffusers/models/attention_dispatch.py

Lines changed: 2 additions & 4 deletions
@@ -314,10 +314,8 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
 ):
     seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
     seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
-    cu_seqlens_k = torch.cumsum(seqlens_q, dim=0, dtype=torch.int32)
-    cu_seqlens_q = torch.cumsum(seqlens_k, dim=0, dtype=torch.int32)
-    cu_seqlens_q = torch.nn.functional.pad(cu_seqlens_q, (1, 0))
-    cu_seqlens_k = torch.nn.functional.pad(cu_seqlens_k, (1, 0))
+    cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
+    cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
     return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
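
For context, this change does two things at once: it folds each cumsum/pad pair into a single expression, and it fixes a q/k swap in the old code, which assigned torch.cumsum(seqlens_q, ...) to cu_seqlens_k and torch.cumsum(seqlens_k, ...) to cu_seqlens_q, producing wrong cumulative offsets whenever seq_len_q != seq_len_kv. Below is a minimal, self-contained sketch of the fixed logic; the helper name here is hypothetical, and the parameter list (batch_size, seq_len_q, seq_len_kv, device) is inferred from the hunk body, not copied from the full signature in attention_dispatch.py.

import torch

# Hypothetical standalone reproduction of the fixed helper; parameter
# names are inferred from the diff hunk, not from the full signature.
def prepare_varlen_metadata(batch_size, seq_len_q, seq_len_kv, device=None):
    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
    seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
    # Cumulative sequence lengths, left-padded with a leading zero, as the
    # flash-attn / SageAttention varlen kernels expect.
    cu_seqlens_q = torch.nn.functional.pad(
        torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0)
    )
    cu_seqlens_k = torch.nn.functional.pad(
        torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0)
    )
    max_seqlen_q = seqlens_q.max().item()
    max_seqlen_k = seqlens_k.max().item()
    return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)

# With seq_len_q != seq_len_kv, the pre-fix code swapped the two outputs:
(cu_q, cu_k), (max_q, max_k) = prepare_varlen_metadata(2, seq_len_q=3, seq_len_kv=5)
print(cu_q.tolist())  # [0, 3, 6]   (old code produced [0, 5, 10] here)
print(cu_k.tolist())  # [0, 5, 10]  (old code produced [0, 3, 6] here)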
