
Commit 171152f

support sage attention with cp
1 parent e76fc94 commit 171152f
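
This change renames the existing autograd wrappers (`_cudnn_attention` to `_cudnn_attention_af`, `_flash_attention_2` to `_flash_attention_2_af`), adds a forward-only `_sage_attention_af` wrapper, and registers the SAGE backend with `supports_context_parallel=True`, so `_sage_attention` can route through `_templated_context_parallel_attention` whenever a parallel config is set. A rough sketch of what this enables at the user level follows; `set_attention_backend`, `enable_parallelism`, and `ContextParallelConfig` are assumptions based on the surrounding diffusers APIs and may differ between versions.

import torch
from diffusers import ContextParallelConfig, WanTransformer3DModel  # imports assumed; model choice is only an example

# Assumed launch: torchrun --nproc-per-node=2 this_script.py (single node)
torch.distributed.init_process_group("nccl")
device = torch.device("cuda", torch.distributed.get_rank())

transformer = WanTransformer3DModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
).to(device)
transformer.set_attention_backend("sage")  # assumed string for AttentionBackendName.SAGE
transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))  # assumed CP API; sets the parallel config the registry sees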

1 file changed

src/diffusers/models/attention_dispatch.py

Lines changed: 77 additions & 17 deletions
@@ -556,7 +556,7 @@ def _(
 # ===== Autograd functions =====
 
 
-class _cudnn_attention(torch.autograd.Function):
+class _cudnn_attention_af(torch.autograd.Function):
     # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
     # forward declaration:
     #   aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
@@ -614,7 +614,7 @@ def forward(
     def backward(
         ctx: torch.autograd.function.FunctionCtx,
         grad_out: torch.Tensor,
-        *args: torch.Tensor,
+        *args,
     ):
         query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
         grad_out = grad_out.transpose(1, 2).contiguous()
@@ -644,7 +644,7 @@ def backward(
 
 
 # Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
-class _flash_attention_2(torch.autograd.Function):
+class _flash_attention_2_af(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx: torch.autograd.function.FunctionCtx,
@@ -707,7 +707,7 @@ def forward(
     def backward(
         ctx: torch.autograd.function.FunctionCtx,
         grad_out: torch.Tensor,
-        *args: torch.Tensor,
+        *args,
     ):
         query, key, value, out, lse, rng_state = ctx.saved_tensors
         grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
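
The next hunk adds `_sage_attention_af`, which follows the same pattern as the other `*_af` wrappers above: a `torch.autograd.Function` that implements only the forward pass, so the templated context-parallel code can call the kernel per shard through a uniform `op` interface, while training through it fails loudly. A minimal standalone sketch of that pattern (illustrative only, using plain SDPA in place of an inference-only kernel such as `sageattn`):

import torch
import torch.nn.functional as F


class _forward_only_attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, query, key, value):
        # Stand-in for a fused inference-only kernel; nothing is saved for backward.
        return F.scaled_dot_product_attention(query, key, value)

    @staticmethod
    def backward(ctx, grad_out, *args):
        # Gradients are intentionally unsupported for this wrapper.
        raise NotImplementedError("Backward pass is not implemented.")


q = k = v = torch.randn(1, 4, 16, 8)  # [batch, heads, seq, head_dim] for SDPA
out = _forward_only_attention.apply(q, k, v)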
@@ -741,6 +741,51 @@ def backward(
         return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
 
 
+class _sage_attention_af(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: torch.autograd.function.FunctionCtx,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+        scale: Optional[float] = None,
+        enable_gqa: bool = False,
+        return_lse: bool = False,
+    ):
+        if attn_mask is not None:
+            raise ValueError("`attn_mask` is not yet supported for Sage attention.")
+        if dropout_p > 0.0:
+            raise ValueError("`dropout_p` is not yet supported for Sage attention.")
+        if enable_gqa:
+            raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
+
+        out = sageattn(
+            q=query,
+            k=key,
+            v=value,
+            tensor_layout="NHD",
+            is_causal=is_causal,
+            sm_scale=scale,
+            return_lse=return_lse,
+        )
+        lse = None
+        if return_lse:
+            out, lse, *_ = out
+
+        return (out, lse) if return_lse else out
+
+    @staticmethod
+    def backward(
+        ctx: torch.autograd.function.FunctionCtx,
+        grad_out: torch.Tensor,
+        *args,
+    ):
+        raise NotImplementedError("Backward pass is not implemented for Sage attention.")
+
+
 # ===== Context parallel =====
 
 
@@ -799,7 +844,7 @@ def forward(
     def backward(
         ctx: torch.autograd.function.FunctionCtx,
         grad_out: torch.Tensor,
-        *args: torch.Tensor,
+        *args,
     ):
         raise NotImplementedError("Backward pass is not implemented for TemplatedRingAttention.")
 
@@ -854,7 +899,7 @@ def forward(
     def backward(
         ctx: torch.autograd.function.FunctionCtx,
         grad_out: torch.Tensor,
-        *args: torch.Tensor,
+        *args,
     ):
         raise NotImplementedError("Backward pass is not implemented for TemplatedUlyssesAttention.")
 
@@ -927,7 +972,7 @@ def _flash_attention(
             out, lse, *_ = out
     else:
         out = _templated_context_parallel_attention(
-            query, key, value, None, dropout_p, is_causal, scale, False, return_lse, op=_flash_attention_2
+            query, key, value, None, dropout_p, is_causal, scale, False, return_lse, op=_flash_attention_2_af
         )
         if return_lse:
             out, lse = out
@@ -1191,7 +1236,7 @@ def _native_cudnn_attention(
         out = out.permute(0, 2, 1, 3)
     else:
         out = _templated_context_parallel_attention(
-            query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op=_cudnn_attention
+            query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op=_cudnn_attention_af
         )
         if return_lse:
             out, lse = out
@@ -1356,6 +1401,7 @@ def _native_xla_attention(
 @_AttentionBackendRegistry.register(
     AttentionBackendName.SAGE,
     constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=True,
 )
 def _sage_attention(
     query: torch.Tensor,
@@ -1365,15 +1411,29 @@ def _sage_attention(
     scale: Optional[float] = None,
     return_lse: bool = False,
 ) -> torch.Tensor:
-    return sageattn(
-        q=query,
-        k=key,
-        v=value,
-        tensor_layout="NHD",
-        is_causal=is_causal,
-        sm_scale=scale,
-        return_lse=return_lse,
-    )
+    parallel_config = _AttentionBackendRegistry._parallel_config
+
+    lse = None
+    if parallel_config is None:
+        out = sageattn(
+            q=query,
+            k=key,
+            v=value,
+            tensor_layout="NHD",
+            is_causal=is_causal,
+            sm_scale=scale,
+            return_lse=return_lse,
+        )
+        if return_lse:
+            out, lse, *_ = out
+    else:
+        out = _templated_context_parallel_attention(
+            query, key, value, None, 0.0, is_causal, scale, False, return_lse, op=_sage_attention_af
+        )
+        if return_lse:
+            out, lse = out
+
+    return (out, lse) if return_lse else out
 
 
 @_AttentionBackendRegistry.register(
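
With a parallel config present, `_sage_attention` now defers to `_templated_context_parallel_attention` with `op=_sage_attention_af`, and the log-sum-exp (`lse`) returned by the kernel is what allows per-shard partial results to be combined, e.g. in ring attention where each rank only sees a slice of the keys/values. As a refresher, a minimal sketch of that LSE-based merge rule (illustrative math, not the library's implementation; it assumes each shard returns `(out_i, lse_i)` for the same queries over disjoint KV blocks, with `lse` laid out as `[batch, seq, heads]`):

import torch


def merge_attention_shards(outs, lses):
    # outs: list of [B, S, H, D] partial attention outputs over disjoint KV blocks
    # lses: list of [B, S, H] log-sum-exp values for the same queries
    lse_stack = torch.stack(lses)                  # [N, B, S, H]
    lse_total = torch.logsumexp(lse_stack, dim=0)  # combined normalizer over all shards
    weights = torch.exp(lse_stack - lse_total)     # per-shard softmax renormalization
    out_stack = torch.stack(outs)                  # [N, B, S, H, D]
    return (weights.unsqueeze(-1) * out_stack).sum(dim=0)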
