@@ -751,6 +751,11 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor],
+        dropout_p: float,
+        scale: Optional[float],
+        is_causal: bool,
+        enable_gqa: bool,
         return_lse: bool,
         op: torch.autograd.Function,
     ):
@@ -773,7 +778,7 @@ def forward(
                 value = kv[key.numel() :].reshape_as(value)
                 next_rank = (next_rank + 1) % world_size

-            out, lse = op.apply(query, key, value, None, 0.0, None, False, False, True)
+            out, lse = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, True)

             if parallel_config.convert_to_fp32:
                 out = out.to(torch.float32)
@@ -806,6 +811,11 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor],
+        dropout_p: float,
+        scale: Optional[float],
+        is_causal: bool,
+        enable_gqa: bool,
         return_lse: bool,
         op: torch.autograd.Function,
     ):
@@ -823,7 +833,7 @@ def forward(
         query, key, value = (funcol.all_to_all_single(x, None, None, group=group).wait() for x in (query, key, value))
         query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))

-        out = op.apply(query, key, value, None, 0.0, None, False, False, return_lse)
+        out = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse)
         if return_lse:
             out, lse, *_ = out

@@ -872,9 +882,13 @@ def _templated_context_parallel_attention(
     parallel_config = _AttentionBackendRegistry._parallel_config
     # TODO: add support for unified attention with ring/ulysses degree both being > 1
     if parallel_config.ring_degree > 1:
-        return TemplatedRingAttention.apply(query, key, value, return_lse, op)
+        return TemplatedRingAttention.apply(
+            query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op
+        )
     elif parallel_config.ulysses_degree > 1:
-        return TemplatedUlyssesAttention.apply(query, key, value, return_lse, op)
+        return TemplatedUlyssesAttention.apply(
+            query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op
+        )
     else:
         return op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse)

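For context, the five arguments threaded through TemplatedRingAttention and TemplatedUlyssesAttention above mirror the keyword set of torch.nn.functional.scaled_dot_product_attention. The snippet below is a minimal, standalone sketch of that keyword set on a single device, not part of the patched code: the tensor shapes are made up for illustration, only the scaled_dot_product_attention call itself is a real PyTorch API, and enable_gqa assumes a recent PyTorch release (2.5 or newer).

import torch
import torch.nn.functional as F

# Illustrative shapes only: 8 query heads sharing 2 key/value heads (grouped-query attention).
batch, q_heads, kv_heads, seq_len, head_dim = 2, 8, 2, 16, 64
query = torch.randn(batch, q_heads, seq_len, head_dim)
key = torch.randn(batch, kv_heads, seq_len, head_dim)
value = torch.randn(batch, kv_heads, seq_len, head_dim)

# The same options the patched forward() now passes through to op.apply(...):
# attn_mask, dropout_p, scale, is_causal, enable_gqa.
out = F.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,   # optional boolean or additive mask; left None here because is_causal=True
    dropout_p=0.0,    # attention dropout probability
    is_causal=True,   # apply a causal mask instead of an explicit attn_mask
    scale=None,       # None means the default 1/sqrt(head_dim) scaling
    enable_gqa=True,  # allow q_heads to be a multiple of kv_heads
)
print(out.shape)  # torch.Size([2, 8, 16, 64])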