@@ -571,8 +571,8 @@ def forward(
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
         dropout_p: float = 0.0,
-        scale: Optional[float] = None,
         is_causal: bool = False,
+        scale: Optional[float] = None,
         enable_gqa: bool = False,
         return_lse: bool = False,
     ):
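For reference, the reordered keyword arguments (`is_causal` before `scale`) match the keyword order of `torch.nn.functional.scaled_dot_product_attention`, and the same reordering is applied to the remaining signatures and call sites below. A minimal illustrative call with that ordering (not part of this commit; shapes are made up):

import torch
import torch.nn.functional as F

# Dummy (batch, heads, seq_len, head_dim) tensors, purely for illustration.
query = torch.randn(2, 8, 128, 64)
key = torch.randn(2, 8, 128, 64)
value = torch.randn(2, 8, 128, 64)

# Same keyword order as the reordered signatures above:
# attn_mask, dropout_p, is_causal, scale (the scale keyword requires PyTorch >= 2.1).
out = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True,
    scale=None,  # None -> 1 / sqrt(head_dim)
)
print(out.shape)  # torch.Size([2, 8, 128, 64])
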
@@ -653,8 +653,8 @@ def forward(
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
         dropout_p: float = 0.0,
-        scale: Optional[float] = None,
         is_causal: bool = False,
+        scale: Optional[float] = None,
         enable_gqa: bool = False,
         return_lse: bool = False,
     ):
@@ -753,8 +753,8 @@ def forward(
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor],
         dropout_p: float,
-        scale: Optional[float],
         is_causal: bool,
+        scale: Optional[float],
         enable_gqa: bool,
         return_lse: bool,
         op: torch.autograd.Function,
@@ -778,7 +778,7 @@ def forward(
                 value = kv[key.numel() :].reshape_as(value)
             next_rank = (next_rank + 1) % world_size

-            out, lse = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, True)
+            out, lse = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, True)

             if parallel_config.convert_to_fp32:
                 out = out.to(torch.float32)
@@ -813,8 +813,8 @@ def forward(
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor],
         dropout_p: float,
-        scale: Optional[float],
         is_causal: bool,
+        scale: Optional[float],
         enable_gqa: bool,
         return_lse: bool,
         op: torch.autograd.Function,
@@ -833,7 +833,7 @@ def forward(
         query, key, value = (funcol.all_to_all_single(x, None, None, group=group).wait() for x in (query, key, value))
         query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))

-        out = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse)
+        out = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
         if return_lse:
             out, lse, *_ = out
@@ -883,14 +883,14 @@ def _templated_context_parallel_attention(
     # TODO: add support for unified attention with ring/ulysses degree both being > 1
     if parallel_config.ring_degree > 1:
         return TemplatedRingAttention.apply(
-            query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op
+            query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op
         )
     elif parallel_config.ulysses_degree > 1:
         return TemplatedUlyssesAttention.apply(
-            query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op
+            query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op
         )
     else:
-        return op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse)
+        return op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)


 # ===== Attention backends =====
@@ -905,20 +905,33 @@ def _flash_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     dropout_p: float = 0.0,
-    scale: Optional[float] = None,
     is_causal: bool = False,
+    scale: Optional[float] = None,
     return_lse: bool = False,
 ) -> torch.Tensor:
-    out = flash_attn_func(
-        q=query,
-        k=key,
-        v=value,
-        dropout_p=dropout_p,
-        softmax_scale=scale,
-        causal=is_causal,
-        return_attn_probs=return_lse,
-    )
-    return out
+    parallel_config = _AttentionBackendRegistry._parallel_config
+
+    lse = None
+    if parallel_config is None:
+        out = flash_attn_func(
+            q=query,
+            k=key,
+            v=value,
+            dropout_p=dropout_p,
+            softmax_scale=scale,
+            causal=is_causal,
+            return_attn_probs=return_lse,
+        )
+        if return_lse:
+            out, lse, *_ = out
+    else:
+        out = _templated_context_parallel_attention(
+            query, key, value, None, dropout_p, is_causal, scale, False, return_lse, op=_flash_attention_2
+        )
+        if return_lse:
+            out, lse = out
+
+    return (out, lse) if return_lse else out


 @_AttentionBackendRegistry.register(
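In the non-parallel branch above, `flash_attn_func(..., return_attn_probs=True)` returns a tuple whose first two elements are the attention output and the softmax log-sum-exp, which is why the code unpacks `out, lse, *_ = out`. A rough usage sketch of that call pattern (assuming flash-attn 2.x, a CUDA device, and fp16 inputs; shapes are illustrative):

import torch
from flash_attn import flash_attn_func

# flash_attn_func expects (batch, seq_len, num_heads, head_dim) tensors
# in fp16/bf16 on a CUDA device.
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)

# With return_attn_probs=True the first two returned values are the
# attention output and the log-sum-exp, used above as `lse`.
out, lse, *_ = flash_attn_func(
    q=q,
    k=k,
    v=v,
    dropout_p=0.0,
    softmax_scale=None,  # None -> 1 / sqrt(head_dim)
    causal=False,
    return_attn_probs=True,
)
print(out.shape, lse.shape)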