@@ -616,6 +616,7 @@ def _cudnn_attention_forward_op(
     scale: Optional[float] = None,
     enable_gqa: bool = False,
     return_lse: bool = False,
+    _save_ctx: bool = True,
 ):
     if enable_gqa:
         raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.")
@@ -625,9 +626,9 @@ def _cudnn_attention_forward_op(
     # Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results
     # if the input tensors are not contiguous.
     query = query.transpose(1, 2).contiguous()
-    tensors_to_save += (query, key, value)
     key = key.transpose(1, 2).contiguous()
     value = value.transpose(1, 2).contiguous()
+    tensors_to_save += (query, key, value)
 
     out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
         torch.ops.aten._scaled_dot_product_cudnn_attention(
@@ -644,13 +645,14 @@ def _cudnn_attention_forward_op(
     )
 
     tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
-    ctx.save_for_backward(*tensors_to_save)
-    ctx.dropout_p = dropout_p
-    ctx.is_causal = is_causal
-    ctx.scale = scale
-    ctx.attn_mask = attn_mask
-    ctx.max_q = max_q
-    ctx.max_k = max_k
+    if _save_ctx:
+        ctx.save_for_backward(*tensors_to_save)
+        ctx.dropout_p = dropout_p
+        ctx.is_causal = is_causal
+        ctx.scale = scale
+        ctx.attn_mask = attn_mask
+        ctx.max_q = max_q
+        ctx.max_k = max_k
 
     out = out.transpose(1, 2).contiguous()
     if lse is not None:
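Note: the new `_save_ctx` flag lets a caller that invokes this op several times against the same `ctx` (the ring-attention loop later in this diff passes `_save_ctx=i == 0`) record autograd state only on the first call. A minimal, self-contained sketch of that gating pattern, with a toy op and a toy ctx standing in for the real cuDNN kernel and autograd context:

```python
import torch

# Hedged sketch, not the library's code: `_ToyCtx` and `toy_attention_op` are made up.
class _ToyCtx:
    def save_for_backward(self, *tensors):
        self.saved = tensors

def toy_attention_op(ctx, query, key, value, _save_ctx=True):
    out = torch.softmax(query @ key.transpose(-2, -1) / query.shape[-1] ** 0.5, dim=-1) @ value
    if _save_ctx:
        # Only the first call in a multi-step loop populates ctx; later calls skip re-saving.
        ctx.save_for_backward(query, key, value, out)
    return out

ctx = _ToyCtx()
q = k = v = torch.randn(2, 4, 8)
for i in range(4):  # stands in for the ring steps, which pass _save_ctx=i == 0
    out = toy_attention_op(ctx, q, k, v, _save_ctx=i == 0)
```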
@@ -666,8 +668,7 @@ def _cudnn_attention_backward_op(
     *args,
     **kwargs,
 ):
-    saved_tensors = ctx.to_save
-    query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = saved_tensors
+    query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
 
     grad_out = grad_out.transpose(1, 2).contiguous()
     key = key.transpose(1, 2).contiguous()
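Reading `ctx.saved_tensors` here follows the documented `torch.autograd.Function` contract; the previous `ctx.to_save` was the internal staging attribute that `save_for_backward` happens to set on a bare `FunctionCtx`. For reference, the standard round trip (unrelated to attention) looks like this:

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)      # hand tensors to autograd
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors      # the supported way to read them back
        return 2 * x * grad_out

x = torch.randn(3, requires_grad=True)
Square.apply(x).sum().backward()
assert torch.allclose(x.grad, 2 * x)
```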
@@ -709,6 +710,7 @@ def _flash_attention_forward_op(
     scale: Optional[float] = None,
     enable_gqa: bool = False,
     return_lse: bool = False,
+    _save_ctx: bool = True,
 ):
     if attn_mask is not None:
         raise ValueError("`attn_mask` is not yet supported for flash-attn 2.")
@@ -746,14 +748,15 @@ def _flash_attention_forward_op(
         )
         lse = lse.permute(0, 2, 1)
 
-    ctx.save_for_backward(query, key, value, out, lse, rng_state)
-    ctx.dropout_p = dropout_p
-    ctx.scale = scale
-    ctx.is_causal = is_causal
-    ctx.window_size = window_size
-    ctx.softcap = softcap
-    ctx.alibi_slopes = alibi_slopes
-    ctx.deterministic = deterministic
+    if _save_ctx:
+        ctx.save_for_backward(query, key, value, out, lse, rng_state)
+        ctx.dropout_p = dropout_p
+        ctx.scale = scale
+        ctx.is_causal = is_causal
+        ctx.window_size = window_size
+        ctx.softcap = softcap
+        ctx.alibi_slopes = alibi_slopes
+        ctx.deterministic = deterministic
 
     return (out, lse) if return_lse else out
 
@@ -764,8 +767,7 @@ def _flash_attention_backward_op(
     *args,
     **kwargs,
 ):
-    saved_tensors = ctx.to_save
-    query, key, value, out, lse, rng_state = saved_tensors
+    query, key, value, out, lse, rng_state = ctx.saved_tensors
     grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
 
     lse_d = _wrapped_flash_attn_backward(  # noqa: F841
@@ -808,6 +810,7 @@ def _sage_attention_forward_op(
     scale: Optional[float] = None,
     enable_gqa: bool = False,
     return_lse: bool = False,
+    _save_ctx: bool = True,
 ):
     if attn_mask is not None:
         raise ValueError("`attn_mask` is not yet supported for Sage attention.")
@@ -830,8 +833,6 @@ def _sage_attention_forward_op(
         out, lse, *_ = out
         lse = lse.permute(0, 2, 1)
 
-    ctx.save_for_backward(query, key, value, out, lse)
-
     return (out, lse) if return_lse else out
 
 
@@ -892,15 +893,10 @@ def forward(
         next_rank = (rank + 1) % world_size
         prev_out = prev_lse = None
 
-        ctx.save_for_backward(query, key, value)
-        ctx.dropout_p = dropout_p
-        ctx.is_causal = is_causal
-        ctx.scale = scale
-        ctx.enable_gqa = enable_gqa
-        ctx.return_lse = return_lse
         ctx.forward_op = forward_op
         ctx.backward_op = backward_op
-        ctx.op_ctx = torch.autograd.function.FunctionCtx()
+        ctx.q_shape = query.shape
+        ctx.kv_shape = key.shape
 
         kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
         kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
@@ -915,7 +911,7 @@ def forward(
                 next_rank = (next_rank + 1) % world_size
 
             out, lse = forward_op(
-                ctx.op_ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, True
+                ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, True, _save_ctx=i == 0
             )
 
             if parallel_config.convert_to_fp32:
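The loop's `prev_out`/`prev_lse` variables and the `convert_to_fp32` branch feed the usual ring-attention combine step, which is not part of this hunk. For orientation, a hedged sketch of that log-sum-exp merge; the function name and shapes are illustrative, not the repository's actual helper:

```python
import torch

def merge_partials(prev_out, prev_lse, out, lse):
    # Numerically stable combination of two partial attention results computed
    # against different key/value shards; lse is the log-sum-exp over keys.
    if prev_out is None:
        return out, lse
    new_lse = torch.logaddexp(prev_lse, lse)
    new_out = prev_out * (prev_lse - new_lse).exp().unsqueeze(-1) + out * (lse - new_lse).exp().unsqueeze(-1)
    return new_out, new_lse

# Example with out of shape [B, S, H, D] and lse of shape [B, S, H].
o1, o2 = torch.randn(2, 4, 3, 8), torch.randn(2, 4, 3, 8)
l1, l2 = torch.randn(2, 4, 3), torch.randn(2, 4, 3)
merged_out, merged_lse = merge_partials(*merge_partials(None, None, o1, l1), o2, l2)
```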
@@ -947,14 +943,13 @@ def backward(
         next_rank = (rank + 1) % world_size
         next_ranks = list(range(1, world_size)) + [0]
 
-        query, key, value = ctx.saved_tensors
-
-        accum_dtype = torch.float32 if parallel_config.convert_to_fp32 else query.dtype
-        grad_query = torch.zeros_like(query, dtype=accum_dtype)
-        grad_key = torch.zeros_like(key, dtype=accum_dtype)
-        grad_value = torch.zeros_like(value, dtype=accum_dtype)
+        accum_dtype = torch.float32 if parallel_config.convert_to_fp32 else grad_out.dtype
+        grad_query = torch.zeros(ctx.q_shape, dtype=accum_dtype, device=grad_out.device)
+        grad_key = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device)
+        grad_value = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device)
         next_grad_kv = None
 
+        query, key, value, *_ = ctx.saved_tensors
         kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
         kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
         kv_buffer = kv_buffer.chunk(world_size)
@@ -967,12 +962,7 @@ def backward(
                 value = kv[key_numel:].reshape_as(value)
                 next_rank = (next_rank + 1) % world_size
 
-            saved_tensors = list(ctx.op_ctx.to_save)
-            saved_tensors[1] = key
-            saved_tensors[2] = value
-            ctx.op_ctx.to_save = tuple(saved_tensors)
-
-            grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx.op_ctx, grad_out)
+            grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out)
 
             if i > 0:
                 grad_kv_buffer = _wait_tensor(next_grad_kv)
@@ -988,6 +978,8 @@ def backward(
                 grad_kv_buffer = torch.cat([grad_key.flatten(), grad_value.flatten()]).contiguous()
                 next_grad_kv = funcol.permute_tensor(grad_kv_buffer, next_ranks, group=ring_mesh.get_group())
 
+        grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
+
         return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
 
 
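This backward now accumulates into freshly allocated buffers (shapes remembered on `ctx` during forward) and, new in this hunk, casts the accumulators back to `grad_out.dtype` before returning, so the gradients handed to autograd match the dtype of the inputs even when `convert_to_fp32` accumulates in float32. A small, self-contained illustration of that pattern:

```python
import torch

# Illustrative only: accumulate per-step partial gradients in float32, then
# cast back to the dtype autograd expects (here bfloat16, matching grad_out).
grad_out = torch.randn(2, 8, 4, 16, dtype=torch.bfloat16)
q_shape = grad_out.shape

accum_dtype = torch.float32  # as if parallel_config.convert_to_fp32 were set
grad_query = torch.zeros(q_shape, dtype=accum_dtype, device=grad_out.device)
for _ in range(4):  # one partial gradient per ring step
    grad_query += torch.randn(q_shape, device=grad_out.device)  # stand-in for backward_op output
grad_query = grad_query.to(grad_out.dtype)
assert grad_query.dtype == grad_out.dtype
```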
@@ -1014,7 +1006,6 @@ def forward(
 
         ctx.forward_op = forward_op
         ctx.backward_op = backward_op
-        ctx.op_ctx = torch.autograd.function.FunctionCtx()
 
         B, S_Q_LOCAL, H, D = query.shape
         _, S_KV_LOCAL, _, _ = key.shape
@@ -1025,7 +1016,9 @@ def forward(
         query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
         query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
 
-        out = forward_op(ctx.op_ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
+        out = forward_op(
+            ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, _save_ctx=True
+        )
         if return_lse:
             out, lse, *_ = out
 
@@ -1060,7 +1053,7 @@ def backward(
         grad_out = _all_to_all_single(grad_out, group)
         grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
 
-        grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx.op_ctx, grad_out)
+        grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out)
 
         grad_query, grad_key, grad_value = (
             x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()