
Commit 38a89ed

Recover attention_dispatch.py with its original implementation; a dedicated follow-up commit will handle FA3 compatibility.
1 parent 6c0c059 commit 38a89ed

File tree

1 file changed: +29 -126 lines changed


src/diffusers/models/attention_dispatch.py

Lines changed: 29 additions & 126 deletions
@@ -18,17 +18,7 @@
 import math
 from dataclasses import dataclass
 from enum import Enum
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Union,
-)
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
 
@@ -78,10 +68,7 @@
 
 if _CAN_USE_FLASH_ATTN:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.flash_attn_interface import (
-        _wrapped_flash_attn_backward,
-        _wrapped_flash_attn_forward,
-    )
+    from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
 else:
     flash_attn_func = None
     flash_attn_varlen_func = None
@@ -90,13 +77,11 @@
 
 
 if _CAN_USE_FLASH_ATTN_3:
-    from flash_attn_interface import _flash_attn_forward as flash_attn_3_forward
     from flash_attn_interface import flash_attn_func as flash_attn_3_func
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
 else:
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
-    flash_attn_3_forward = None
 
 if _CAN_USE_AITER_ATTN:
     from aiter import flash_attn_func as aiter_flash_attn_func
@@ -135,9 +120,7 @@
 
 
 if _CAN_USE_XLA_ATTN:
-    from torch_xla.experimental.custom_kernel import (
-        flash_attention as xla_flash_attention,
-    )
+    from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
 else:
     xla_flash_attention = None
 
@@ -280,17 +263,13 @@ class _HubKernelConfig:
 _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
     # TODO: temporary revision for now. Remove when merged upstream into `main`.
     AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
-        repo_id="kernels-community/flash-attn3",
-        function_attr="flash_attn_func",
-        revision="fake-ops-return-probs",
+        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
     )
 }
 
 
 @contextlib.contextmanager
-def attention_backend(
-    backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE,
-):
+def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
     """
     Context manager to set the active attention backend.
     """
@@ -435,10 +414,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                 f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
             )
 
-    elif backend in [
-        AttentionBackendName._FLASH_3,
-        AttentionBackendName._FLASH_VARLEN_3,
-    ]:
+    elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
         if not _CAN_USE_FLASH_ATTN_3:
             raise RuntimeError(
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
@@ -510,11 +486,7 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
     cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
-    return (
-        (seqlens_q, seqlens_k),
-        (cu_seqlens_q, cu_seqlens_k),
-        (max_seqlen_q, max_seqlen_k),
-    )
+    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
 
 
 def _prepare_for_flash_attn_or_sage_varlen_with_mask(
@@ -531,11 +503,7 @@ def _prepare_for_flash_attn_or_sage_varlen_with_mask(
     cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
     max_seqlen_q = seqlens_q.max().item()
    max_seqlen_k = seqlens_k.max().item()
-    return (
-        (seqlens_q, seqlens_k),
-        (cu_seqlens_q, cu_seqlens_k),
-        (max_seqlen_q, max_seqlen_k),
-    )
+    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
 
 
 def _prepare_for_flash_attn_or_sage_varlen(
@@ -653,42 +621,22 @@ def _wrapped_flash_attn_3(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Hardcoded for now because pytorch does not support tuple/int type hints
     window_size = (-1, -1)
-    max_seqlen_q = q.shape[2]
-    max_seqlen_k = k.shape[2]
-
-    out, lse, *_ = flash_attn_3_forward(
+    out, lse, *_ = flash_attn_3_func(
         q=q,
         k=k,
         v=v,
-        k_new=None,
-        v_new=None,
+        softmax_scale=softmax_scale,
+        causal=causal,
         qv=qv,
-        out=None,
-        cu_seqlens_q=None,
-        cu_seqlens_k=None,
-        cu_seqlens_k_new=None,
-        seqused_q=None,
-        seqused_k=None,
-        max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_k,
-        page_table=None,
-        kv_batch_idx=None,
-        leftpad_k=None,
-        rotary_cos=None,
-        rotary_sin=None,
-        seqlens_rotary=None,
         q_descale=q_descale,
         k_descale=k_descale,
         v_descale=v_descale,
-        softmax_scale=softmax_scale,
-        causal=causal,
         window_size=window_size,
         attention_chunk=attention_chunk,
         softcap=softcap,
-        rotary_interleaved=True,
-        scheduler_metadata=None,
         num_splits=num_splits,
         pack_gqa=pack_gqa,
+        deterministic=deterministic,
         sm_margin=sm_margin,
     )
     lse = lse.permute(0, 2, 1)
@@ -794,10 +742,7 @@ def _native_attention_backward_op(
 
     grad_out_t = grad_out.permute(0, 2, 1, 3)
     grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
-        outputs=out,
-        inputs=[query_t, key_t, value_t],
-        grad_outputs=grad_out_t,
-        retain_graph=False,
+        outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
     )
 
     grad_query = grad_query_t.permute(0, 2, 1, 3)
@@ -836,26 +781,18 @@ def _cudnn_attention_forward_op(
     value = value.transpose(1, 2).contiguous()
     tensors_to_save += (query, key, value)
 
-    (
-        out,
-        lse,
-        cum_seq_q,
-        cum_seq_k,
-        max_q,
-        max_k,
-        philox_seed,
-        philox_offset,
-        debug_attn_mask,
-    ) = torch.ops.aten._scaled_dot_product_cudnn_attention(
-        query=query,
-        key=key,
-        value=value,
-        attn_bias=attn_mask,
-        compute_log_sumexp=return_lse,
-        dropout_p=dropout_p,
-        is_causal=is_causal,
-        return_debug_mask=False,
-        scale=scale,
+    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
+        torch.ops.aten._scaled_dot_product_cudnn_attention(
+            query=query,
+            key=key,
+            value=value,
+            attn_bias=attn_mask,
+            compute_log_sumexp=return_lse,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            return_debug_mask=False,
+            scale=scale,
+        )
     )
 
     tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
@@ -982,11 +919,7 @@ def _flash_attention_backward_op(
     **kwargs,
 ):
     query, key, value, out, lse, rng_state = ctx.saved_tensors
-    grad_query, grad_key, grad_value = (
-        torch.empty_like(query),
-        torch.empty_like(key),
-        torch.empty_like(value),
-    )
+    grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
 
     lse_d = _wrapped_flash_attn_backward(  # noqa: F841
         grad_out,
@@ -1210,19 +1143,7 @@ def backward(
 
         grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
 
-        return (
-            grad_query,
-            grad_key,
-            grad_value,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-        )
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
 
 
 class TemplatedUlyssesAttention(torch.autograd.Function):
@@ -1317,19 +1238,7 @@ def backward(
             x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
         )
 
-        return (
-            grad_query,
-            grad_key,
-            grad_value,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-        )
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
 
 
 def _templated_context_parallel_attention(
@@ -1677,12 +1586,7 @@ def _native_flex_attention(
         block_mask = attn_mask
     elif is_causal:
         block_mask = flex_attention.create_block_mask(
-            _flex_attention_causal_mask_mod,
-            batch_size,
-            num_heads,
-            seq_len_q,
-            seq_len_kv,
-            query.device,
+            _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device
         )
     elif torch.is_tensor(attn_mask):
         if attn_mask.ndim == 2:
@@ -1702,7 +1606,6 @@ def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
 
             def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
                 return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
-
     else:
         raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")
 

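For context, here is a minimal usage sketch of the attention_backend context manager whose signature the diff above compacts. It is not part of this commit: the import path, the checkpoint id, and the pipe object are illustrative assumptions, and only pipelines whose attention layers route through this dispatcher pick up the selected backend.

import torch
from diffusers import DiffusionPipeline
from diffusers.models.attention_dispatch import AttentionBackendName, attention_backend

# Hypothetical checkpoint id; any diffusers pipeline whose attention is served by this
# dispatcher behaves the same way.
pipe = DiffusionPipeline.from_pretrained("org/some-model", torch_dtype=torch.bfloat16).to("cuda")

# Default: the native PyTorch SDPA backend.
with attention_backend(AttentionBackendName.NATIVE):
    image = pipe("an astronaut riding a horse").images[0]

# With FA3 built from source (see the requirement check in the diff), the FA3 backend can be
# selected the same way; plain strings are also accepted, since the signature above takes
# Union[str, AttentionBackendName].
with attention_backend(AttentionBackendName._FLASH_3):
    image = pipe("an astronaut riding a horse").images[0]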