@@ -556,7 +556,7 @@ def _(
 # ===== Autograd functions =====
 
 
-class _cudnn_attention(torch.autograd.Function):
+class _cudnn_attention_af(torch.autograd.Function):
     # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
     # forward declaration:
     #   aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
@@ -614,7 +614,7 @@ def forward(
     def backward(
         ctx: torch.autograd.function.FunctionCtx,
         grad_out: torch.Tensor,
-        *args: torch.Tensor,
+        *args,
     ):
         query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
         grad_out = grad_out.transpose(1, 2).contiguous()
@@ -644,7 +644,7 @@ def backward(
 
 
 # Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
-class _flash_attention_2(torch.autograd.Function):
+class _flash_attention_2_af(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx: torch.autograd.function.FunctionCtx,
@@ -707,7 +707,7 @@ def forward(
     def backward(
         ctx: torch.autograd.function.FunctionCtx,
         grad_out: torch.Tensor,
-        *args: torch.Tensor,
+        *args,
     ):
         query, key, value, out, lse, rng_state = ctx.saved_tensors
         grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
@@ -741,6 +741,51 @@ def backward(
         return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
 
 
+class _sage_attention_af(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: torch.autograd.function.FunctionCtx,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+        scale: Optional[float] = None,
+        enable_gqa: bool = False,
+        return_lse: bool = False,
+    ):
+        if attn_mask is not None:
+            raise ValueError("`attn_mask` is not yet supported for Sage attention.")
+        if dropout_p > 0.0:
+            raise ValueError("`dropout_p` is not yet supported for Sage attention.")
+        if enable_gqa:
+            raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
+
+        out = sageattn(
+            q=query,
+            k=key,
+            v=value,
+            tensor_layout="NHD",
+            is_causal=is_causal,
+            sm_scale=scale,
+            return_lse=return_lse,
+        )
+        lse = None
+        if return_lse:
+            out, lse, *_ = out
+
+        return (out, lse) if return_lse else out
+
+    @staticmethod
+    def backward(
+        ctx: torch.autograd.function.FunctionCtx,
+        grad_out: torch.Tensor,
+        *args,
+    ):
+        raise NotImplementedError("Backward pass is not implemented for Sage attention.")
+
+
 # ===== Context parallel =====
 
 
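Aside (not part of the diff): a minimal sketch of calling the new `_sage_attention_af` function directly, assuming the `sageattention` package provides `sageattn`, a CUDA device, and half-precision tensors in the (batch, seq_len, num_heads, head_dim) layout implied by `tensor_layout="NHD"` above. The tensor shapes below are made up for illustration; arguments are passed positionally because `torch.autograd.Function.apply` does not accept keywords.

import torch

# Hypothetical shapes for illustration: (batch, seq_len, num_heads, head_dim).
query = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
key = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
value = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)

# Positional args after q/k/v: attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse.
out = _sage_attention_af.apply(query, key, value, None, 0.0, True, None, False, False)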
@@ -799,7 +844,7 @@ def forward(
     def backward(
         ctx: torch.autograd.function.FunctionCtx,
         grad_out: torch.Tensor,
-        *args: torch.Tensor,
+        *args,
     ):
         raise NotImplementedError("Backward pass is not implemented for TemplatedRingAttention.")
 
@@ -854,7 +899,7 @@ def forward(
     def backward(
         ctx: torch.autograd.function.FunctionCtx,
         grad_out: torch.Tensor,
-        *args: torch.Tensor,
+        *args,
     ):
         raise NotImplementedError("Backward pass is not implemented for TemplatedUlyssesAttention.")
 
@@ -927,7 +972,7 @@ def _flash_attention(
             out, lse, *_ = out
     else:
         out = _templated_context_parallel_attention(
-            query, key, value, None, dropout_p, is_causal, scale, False, return_lse, op=_flash_attention_2
+            query, key, value, None, dropout_p, is_causal, scale, False, return_lse, op=_flash_attention_2_af
         )
         if return_lse:
             out, lse = out
@@ -1191,7 +1236,7 @@ def _native_cudnn_attention(
         out = out.permute(0, 2, 1, 3)
     else:
         out = _templated_context_parallel_attention(
-            query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op=_cudnn_attention
+            query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op=_cudnn_attention_af
         )
         if return_lse:
             out, lse = out
@@ -1356,6 +1401,7 @@ def _native_xla_attention(
 @_AttentionBackendRegistry.register(
     AttentionBackendName.SAGE,
     constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=True,
 )
 def _sage_attention(
     query: torch.Tensor,
@@ -1365,15 +1411,29 @@ def _sage_attention(
     scale: Optional[float] = None,
     return_lse: bool = False,
 ) -> torch.Tensor:
-    return sageattn(
-        q=query,
-        k=key,
-        v=value,
-        tensor_layout="NHD",
-        is_causal=is_causal,
-        sm_scale=scale,
-        return_lse=return_lse,
-    )
+    parallel_config = _AttentionBackendRegistry._parallel_config
+
+    lse = None
+    if parallel_config is None:
+        out = sageattn(
+            q=query,
+            k=key,
+            v=value,
+            tensor_layout="NHD",
+            is_causal=is_causal,
+            sm_scale=scale,
+            return_lse=return_lse,
+        )
+        if return_lse:
+            out, lse, *_ = out
+    else:
+        out = _templated_context_parallel_attention(
+            query, key, value, None, 0.0, is_causal, scale, False, return_lse, op=_sage_attention_af
+        )
+        if return_lse:
+            out, lse = out
+
+    return (out, lse) if return_lse else out
 
 
 @_AttentionBackendRegistry.register(
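Note (illustrative, not part of the diff): the flash, cuDNN, and Sage backends above now share the same dispatch shape. The sketch below distills that pattern into a hypothetical helper; `_dispatch` and `single_gpu_fn` are made-up names, and the behavior of `_AttentionBackendRegistry._parallel_config` and `_templated_context_parallel_attention` is assumed to match their use in the hunks above.

def _dispatch(query, key, value, is_causal, scale, return_lse, single_gpu_fn, op):
    # Use the plain single-device kernel when no context-parallel config is set.
    parallel_config = _AttentionBackendRegistry._parallel_config
    lse = None
    if parallel_config is None:
        out = single_gpu_fn(query, key, value, is_causal, scale, return_lse)
        if return_lse:
            out, lse, *_ = out
    else:
        # Otherwise route through the templated context-parallel path,
        # passing the backend's autograd function as `op`.
        out = _templated_context_parallel_attention(
            query, key, value, None, 0.0, is_causal, scale, False, return_lse, op=op
        )
        if return_lse:
            out, lse = out
    return (out, lse) if return_lse else out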