
Commit a820bfd

handle backward correctly
1 parent f4c1b4e commit a820bfd

File tree: 1 file changed (+38, -45 lines)

src/diffusers/models/attention_dispatch.py

Lines changed: 38 additions & 45 deletions
@@ -616,6 +616,7 @@ def _cudnn_attention_forward_op(
     scale: Optional[float] = None,
     enable_gqa: bool = False,
     return_lse: bool = False,
+    _save_ctx: bool = True,
 ):
     if enable_gqa:
         raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.")
@@ -625,9 +626,9 @@ def _cudnn_attention_forward_op(
     # Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results
     # if the input tensors are not contiguous.
     query = query.transpose(1, 2).contiguous()
-    tensors_to_save += (query, key, value)
     key = key.transpose(1, 2).contiguous()
     value = value.transpose(1, 2).contiguous()
+    tensors_to_save += (query, key, value)

     out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
         torch.ops.aten._scaled_dot_product_cudnn_attention(
@@ -644,13 +645,14 @@ def _cudnn_attention_forward_op(
     )

     tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
-    ctx.save_for_backward(*tensors_to_save)
-    ctx.dropout_p = dropout_p
-    ctx.is_causal = is_causal
-    ctx.scale = scale
-    ctx.attn_mask = attn_mask
-    ctx.max_q = max_q
-    ctx.max_k = max_k
+    if _save_ctx:
+        ctx.save_for_backward(*tensors_to_save)
+        ctx.dropout_p = dropout_p
+        ctx.is_causal = is_causal
+        ctx.scale = scale
+        ctx.attn_mask = attn_mask
+        ctx.max_q = max_q
+        ctx.max_k = max_k

     out = out.transpose(1, 2).contiguous()
     if lse is not None:
@@ -666,8 +668,7 @@ def _cudnn_attention_backward_op(
     *args,
     **kwargs,
 ):
-    saved_tensors = ctx.to_save
-    query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = saved_tensors
+    query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors

     grad_out = grad_out.transpose(1, 2).contiguous()
     key = key.transpose(1, 2).contiguous()
@@ -709,6 +710,7 @@ def _flash_attention_forward_op(
     scale: Optional[float] = None,
     enable_gqa: bool = False,
     return_lse: bool = False,
+    _save_ctx: bool = True,
 ):
     if attn_mask is not None:
         raise ValueError("`attn_mask` is not yet supported for flash-attn 2.")
@@ -746,14 +748,15 @@ def _flash_attention_forward_op(
     )
     lse = lse.permute(0, 2, 1)

-    ctx.save_for_backward(query, key, value, out, lse, rng_state)
-    ctx.dropout_p = dropout_p
-    ctx.scale = scale
-    ctx.is_causal = is_causal
-    ctx.window_size = window_size
-    ctx.softcap = softcap
-    ctx.alibi_slopes = alibi_slopes
-    ctx.deterministic = deterministic
+    if _save_ctx:
+        ctx.save_for_backward(query, key, value, out, lse, rng_state)
+        ctx.dropout_p = dropout_p
+        ctx.scale = scale
+        ctx.is_causal = is_causal
+        ctx.window_size = window_size
+        ctx.softcap = softcap
+        ctx.alibi_slopes = alibi_slopes
+        ctx.deterministic = deterministic

     return (out, lse) if return_lse else out

@@ -764,8 +767,7 @@ def _flash_attention_backward_op(
     *args,
     **kwargs,
 ):
-    saved_tensors = ctx.to_save
-    query, key, value, out, lse, rng_state = saved_tensors
+    query, key, value, out, lse, rng_state = ctx.saved_tensors
     grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)

     lse_d = _wrapped_flash_attn_backward(  # noqa: F841
@@ -808,6 +810,7 @@ def _sage_attention_forward_op(
     scale: Optional[float] = None,
     enable_gqa: bool = False,
     return_lse: bool = False,
+    _save_ctx: bool = True,
 ):
     if attn_mask is not None:
         raise ValueError("`attn_mask` is not yet supported for Sage attention.")
@@ -830,8 +833,6 @@ def _sage_attention_forward_op(
         out, lse, *_ = out
         lse = lse.permute(0, 2, 1)

-    ctx.save_for_backward(query, key, value, out, lse)
-
     return (out, lse) if return_lse else out


@@ -892,15 +893,10 @@ def forward(
         next_rank = (rank + 1) % world_size
         prev_out = prev_lse = None

-        ctx.save_for_backward(query, key, value)
-        ctx.dropout_p = dropout_p
-        ctx.is_causal = is_causal
-        ctx.scale = scale
-        ctx.enable_gqa = enable_gqa
-        ctx.return_lse = return_lse
         ctx.forward_op = forward_op
         ctx.backward_op = backward_op
-        ctx.op_ctx = torch.autograd.function.FunctionCtx()
+        ctx.q_shape = query.shape
+        ctx.kv_shape = key.shape

         kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
         kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
@@ -915,7 +911,7 @@ def forward(
                 next_rank = (next_rank + 1) % world_size

             out, lse = forward_op(
-                ctx.op_ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, True
+                ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, True, _save_ctx=i == 0
             )

             if parallel_config.convert_to_fp32:
@@ -947,14 +943,13 @@ def backward(
         next_rank = (rank + 1) % world_size
         next_ranks = list(range(1, world_size)) + [0]

-        query, key, value = ctx.saved_tensors
-
-        accum_dtype = torch.float32 if parallel_config.convert_to_fp32 else query.dtype
-        grad_query = torch.zeros_like(query, dtype=accum_dtype)
-        grad_key = torch.zeros_like(key, dtype=accum_dtype)
-        grad_value = torch.zeros_like(value, dtype=accum_dtype)
+        accum_dtype = torch.float32 if parallel_config.convert_to_fp32 else grad_out.dtype
+        grad_query = torch.zeros(ctx.q_shape, dtype=accum_dtype, device=grad_out.device)
+        grad_key = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device)
+        grad_value = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device)
         next_grad_kv = None

+        query, key, value, *_ = ctx.saved_tensors
         kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
         kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
         kv_buffer = kv_buffer.chunk(world_size)
@@ -967,12 +962,7 @@ def backward(
                 value = kv[key_numel:].reshape_as(value)
                 next_rank = (next_rank + 1) % world_size

-            saved_tensors = list(ctx.op_ctx.to_save)
-            saved_tensors[1] = key
-            saved_tensors[2] = value
-            ctx.op_ctx.to_save = tuple(saved_tensors)
-
-            grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx.op_ctx, grad_out)
+            grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out)

             if i > 0:
                 grad_kv_buffer = _wait_tensor(next_grad_kv)
@@ -988,6 +978,8 @@ def backward(
             grad_kv_buffer = torch.cat([grad_key.flatten(), grad_value.flatten()]).contiguous()
             next_grad_kv = funcol.permute_tensor(grad_kv_buffer, next_ranks, group=ring_mesh.get_group())

+        grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
+
         return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None


@@ -1014,7 +1006,6 @@ def forward(

         ctx.forward_op = forward_op
         ctx.backward_op = backward_op
-        ctx.op_ctx = torch.autograd.function.FunctionCtx()

         B, S_Q_LOCAL, H, D = query.shape
         _, S_KV_LOCAL, _, _ = key.shape
@@ -1025,7 +1016,9 @@ def forward(
         query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
         query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))

-        out = forward_op(ctx.op_ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
+        out = forward_op(
+            ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, _save_ctx=True
+        )
         if return_lse:
             out, lse, *_ = out

@@ -1060,7 +1053,7 @@ def backward(
         grad_out = _all_to_all_single(grad_out, group)
         grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()

-        grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx.op_ctx, grad_out)
+        grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out)

         grad_query, grad_key, grad_value = (
             x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
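
For readers skimming the diff, here is a minimal, self-contained sketch of the pattern the commit converges on. This is hypothetical code, not the diffusers API: the names _attn_forward_op, _attn_backward_op, RingLikeAttention, and the smoke-test shapes are made up, and plain SDPA plus zero gradients stand in for the real backend kernels. The idea it illustrates, taken from the diff: the templated attention autograd Function passes its own ctx into the per-backend forward op, which calls ctx.save_for_backward only when _save_ctx is set (the first ring iteration, since save_for_backward should be called at most once per ctx); the backward op reads ctx.saved_tensors instead of the internal to_save attribute; and gradients are accumulated in buffers allocated from shapes stashed on ctx, then cast back to grad_out.dtype.

import torch
import torch.nn.functional as F


def _attn_forward_op(ctx, query, key, value, _save_ctx=True):
    # Illustrative stand-in for a backend forward op (cuDNN / flash-attn / sage).
    out = F.scaled_dot_product_attention(query, key, value)
    if _save_ctx:
        # save_for_backward should be called at most once, so the caller gates
        # it to the first ring iteration.
        ctx.save_for_backward(query, key, value, out)
    return out


def _attn_backward_op(ctx, grad_out):
    # Read through the public saved_tensors property, not ctx.to_save.
    query, key, value, out = ctx.saved_tensors
    # Placeholder gradients; a real backend would call its backward kernel here.
    return torch.zeros_like(query), torch.zeros_like(key), torch.zeros_like(value)


class RingLikeAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, query, key, value, world_size):
        ctx.world_size = world_size
        # Stash only the shapes (cheap) so backward can allocate gradient
        # buffers without keeping extra copies of the inputs around.
        ctx.q_shape, ctx.kv_shape = query.shape, key.shape
        out = None
        for i in range(world_size):
            # The outer ctx is reused on every iteration; tensors are saved once.
            out = _attn_forward_op(ctx, query, key, value, _save_ctx=i == 0)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # Accumulate in fp32 buffers sized from the stored shapes ...
        grad_q = torch.zeros(ctx.q_shape, dtype=torch.float32, device=grad_out.device)
        grad_k = torch.zeros(ctx.kv_shape, dtype=torch.float32, device=grad_out.device)
        grad_v = torch.zeros(ctx.kv_shape, dtype=torch.float32, device=grad_out.device)
        for _ in range(ctx.world_size):
            gq, gk, gv = _attn_backward_op(ctx, grad_out)
            grad_q, grad_k, grad_v = grad_q + gq, grad_k + gk, grad_v + gv
        # ... and cast back to the incoming gradient's dtype before returning.
        grad_q, grad_k, grad_v = (x.to(grad_out.dtype) for x in (grad_q, grad_k, grad_v))
        return grad_q, grad_k, grad_v, None


# Smoke test: one backward pass through the shared-ctx path.
q, k, v = (torch.randn(1, 2, 8, 16, requires_grad=True) for _ in range(3))
RingLikeAttention.apply(q, k, v, 2).sum().backward()

The design choice visible in the diff is the same: the separate FunctionCtx (ctx.op_ctx) is dropped, only q_shape/kv_shape are stored, and backward no longer has to patch op_ctx.to_save with the rotated key/value before calling the backward op.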
