@@ -1420,13 +1420,17 @@ def _flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
     lse = None
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 2.")
+
     if _parallel_config is None:
         out = flash_attn_func(
             q=query,
@@ -1469,13 +1473,17 @@ def _flash_attention_hub(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
     lse = None
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 2.")
+
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
     out = func(
         q=query,
@@ -1612,11 +1620,15 @@ def _flash_attention_3(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
     is_causal: bool = False,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 3.")
+
     out, lse = _wrapped_flash_attn_3(
         q=query,
         k=key,
@@ -1636,6 +1648,7 @@ def _flash_attention_3_hub(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
     is_causal: bool = False,
     window_size: Tuple[int, int] = (-1, -1),
@@ -1646,6 +1659,8 @@ def _flash_attention_3_hub(
 ) -> torch.Tensor:
     if _parallel_config:
         raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 3.")
 
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
     out = func(
@@ -1785,12 +1800,16 @@ def _aiter_flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for aiter attention")
+
     if not return_lse and torch.is_grad_enabled():
         # aiter requires return_lse=True by assertion when gradients are enabled.
         out, lse, *_ = aiter_flash_attn_func(
@@ -2028,13 +2047,17 @@ def _native_flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
     enable_gqa: bool = False,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for native flash attention")
+
     lse = None
     if _parallel_config is None and not return_lse:
         query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
@@ -2113,11 +2136,14 @@ def _native_npu_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for NPU attention")
     if return_lse:
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
     query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
@@ -2148,10 +2174,13 @@ def _native_xla_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for XLA attention")
     if return_lse:
         raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
     query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
@@ -2175,11 +2204,14 @@ def _sage_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     lse = None
     if _parallel_config is None:
         out = sageattn(
@@ -2223,11 +2255,14 @@ def _sage_attention_hub(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     lse = None
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
     if _parallel_config is None:
@@ -2309,11 +2344,14 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     return sageattn_qk_int8_pv_fp8_cuda(
         q=query,
         k=key,
@@ -2333,11 +2371,14 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     return sageattn_qk_int8_pv_fp8_cuda_sm90(
         q=query,
         k=key,
@@ -2357,11 +2398,14 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     return sageattn_qk_int8_pv_fp16_cuda(
         q=query,
         k=key,
@@ -2381,11 +2425,14 @@ def _sage_qk_int8_pv_fp16_triton_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     return sageattn_qk_int8_pv_fp16_triton(
         q=query,
         k=key,
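
Note: the hunks above all follow one pattern — every backend wrapper accepts `attn_mask` so all backends share a uniform signature, and wrappers whose kernels cannot honor a mask fail fast with a `ValueError`. The sketch below is illustrative only, not code from this diff or the diffusers API; `_flash_like_attention` is a hypothetical stand-in wrapper.

# Illustrative sketch of the guard pattern, under the assumptions stated above.
from typing import Optional

import torch
import torch.nn.functional as F


def _flash_like_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    if attn_mask is not None:
        # Fail fast, mirroring the guards added in this diff.
        raise ValueError("`attn_mask` is not supported for this backend.")
    # Stand-in kernel: the real wrappers call flash-attn / sage / NPU / XLA kernels here.
    return F.scaled_dot_product_attention(query, key, value, is_causal=is_causal, scale=scale)


# Caller-side consequence: supplying a mask requires a mask-capable backend.
q = k = v = torch.randn(1, 2, 4, 8)
mask = torch.ones(1, 1, 4, 4, dtype=torch.bool)
try:
    _flash_like_attention(q, k, v, attn_mask=mask)
except ValueError:
    # Fall back to native SDPA, which does accept a mask.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)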