
Commit bb443f9: "update"

1 parent: 638cc03

2 files changed: +33, -70 lines


src/diffusers/models/attention_dispatch.py

Lines changed: 32 additions & 70 deletions
@@ -314,13 +314,13 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
 ):
     seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
     seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
-    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
-    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+    cu_seqlens_q = torch.cumsum(seqlens_q, dim=0, dtype=torch.int32)
+    cu_seqlens_k = torch.cumsum(seqlens_k, dim=0, dtype=torch.int32)
+    cu_seqlens_q = torch.nn.functional.pad(cu_seqlens_q, (1, 0))
+    cu_seqlens_k = torch.nn.functional.pad(cu_seqlens_k, (1, 0))
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
-    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+    return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)


 def _prepare_for_flash_attn_or_sage_varlen_with_mask(
@@ -331,13 +331,11 @@ def _prepare_for_flash_attn_or_sage_varlen_with_mask(
 ):
     seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
     seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
-    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
-    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+    cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
+    cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
-    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+    return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)


 def _prepare_for_flash_attn_or_sage_varlen(
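
For context: both prepare helpers now build the cumulative sequence-length tensors with a cumsum followed by a left zero-pad, instead of writing into a preallocated buffer, and they no longer return the raw per-batch lengths. A minimal standalone sketch of that construction (the lengths below are invented for illustration, not values from the library):

import torch

# Hypothetical per-batch KV lengths for a batch of 3 packed sequences.
seqlens_k = torch.tensor([3, 5, 2], dtype=torch.int32)

# cumsum gives the end offset of each sequence in the packed layout;
# left-padding a single zero turns it into [0, end_0, end_1, end_2].
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens_k)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)

max_seqlen_k = seqlens_k.max().item()  # 5, the longest sequence in the batch
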
@@ -496,30 +494,18 @@ def _flash_varlen_attention(
         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-            _prepare_for_flash_attn_or_sage_varlen(
-                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-            )
+        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
         )
-    else:
-        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
-        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
-        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
-
-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])

-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
+    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+    cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)

+    query, key, value = (x.flatten(0, 1) for x in (query, key, value))
     out = flash_attn_varlen_func(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
+        q=query,
+        k=key,
+        v=value,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
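
With the offsets handled entirely through cu_seqlens, the call site no longer slices key/value per batch element; it flattens the batch and sequence dimensions into one packed dimension before invoking the varlen kernel. A rough shape-only sketch of that packing step (the tensor sizes are invented; this is not the library call itself):

import torch

batch_size, seq_len, num_heads, head_dim = 2, 4, 8, 64  # arbitrary example sizes
query = torch.randn(batch_size, seq_len, num_heads, head_dim)
key = torch.randn(batch_size, seq_len, num_heads, head_dim)
value = torch.randn(batch_size, seq_len, num_heads, head_dim)

# Pack (batch, seq, heads, dim) -> (batch * seq, heads, dim); the cu_seqlens
# offsets tell the kernel where each batch element starts along dim 0.
query, key, value = (x.flatten(0, 1) for x in (query, key, value))
print(query.shape)  # torch.Size([8, 8, 64])
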
@@ -601,30 +587,18 @@ def _flash_varlen_attention_3(
         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-            _prepare_for_flash_attn_or_sage_varlen(
-                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-            )
+        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
         )
-    else:
-        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
-        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
-        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
-
-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])

-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
+    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+    cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)

+    query, key, value = (x.flatten(0, 1) for x in (query, key, value))
     out, lse, *_ = flash_attn_3_varlen_func(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
+        q=query,
+        k=key,
+        v=value,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
@@ -958,30 +932,18 @@ def _sage_varlen_attention(
         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-            _prepare_for_flash_attn_or_sage_varlen(
-                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-            )
+        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
         )
-    else:
-        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
-        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
-        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)

-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])
-
-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
+    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+    cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)

+    query, key, value = (x.flatten(0, 1) for x in (query, key, value))
     out = sageattn_varlen(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
+        q=query,
+        k=key,
+        v=value,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 1 addition & 0 deletions
@@ -263,6 +263,7 @@ def __call__(
         return hidden_states


+@maybe_allow_in_graph
 class FluxAttention(torch.nn.Module, AttentionModuleMixin):
     _default_processor_cls = FluxAttnProcessor
     _available_processors = [
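
maybe_allow_in_graph is the decorator diffusers applies to other transformer blocks so that, when torch dynamo is available, the class can be kept inside a torch.compile graph. A hedged sketch of applying it in the same way to a toy module (the import path and the toy class are assumptions for illustration, not the FluxAttention implementation):

import torch
from diffusers.utils.torch_utils import maybe_allow_in_graph  # import path assumed

@maybe_allow_in_graph
class ToyAttention(torch.nn.Module):
    # Illustrative stand-in, not FluxAttention.
    def __init__(self, dim: int):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)
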
