 import math
 from dataclasses import dataclass
 from enum import Enum
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Union,
-)
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union

 import torch


 if _CAN_USE_FLASH_ATTN:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.flash_attn_interface import (
-        _wrapped_flash_attn_backward,
-        _wrapped_flash_attn_forward,
-    )
+    from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
 else:
     flash_attn_func = None
     flash_attn_varlen_func = None


 if _CAN_USE_FLASH_ATTN_3:
-    from flash_attn_interface import _flash_attn_forward as flash_attn_3_forward
     from flash_attn_interface import flash_attn_func as flash_attn_3_func
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
 else:
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
-    flash_attn_3_forward = None

 if _CAN_USE_AITER_ATTN:
     from aiter import flash_attn_func as aiter_flash_attn_func


 if _CAN_USE_XLA_ATTN:
-    from torch_xla.experimental.custom_kernel import (
-        flash_attention as xla_flash_attention,
-    )
+    from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
 else:
     xla_flash_attention = None

@@ -280,17 +263,13 @@ class _HubKernelConfig:
 _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
     # TODO: temporary revision for now. Remove when merged upstream into `main`.
     AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
-        repo_id="kernels-community/flash-attn3",
-        function_attr="flash_attn_func",
-        revision="fake-ops-return-probs",
+        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
     )
 }


 @contextlib.contextmanager
-def attention_backend(
-    backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE,
-):
+def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
     """
     Context manager to set the active attention backend.
     """
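For context during review, here is a minimal usage sketch of this context manager. The import path is an assumption based on where this module lives in `diffusers`, and `run_pipeline`/`inputs` are placeholders rather than APIs defined in this file.

# Hedged usage sketch, not part of the diff; the import path and `run_pipeline` are illustrative.
from diffusers.models.attention_dispatch import AttentionBackendName, attention_backend

with attention_backend(AttentionBackendName.NATIVE):
    # Attention dispatched inside this block uses the selected backend; per the signature
    # above, a plain string may be passed instead of the enum member.
    output = run_pipeline(inputs)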
@@ -435,10 +414,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
435414 f"Flash Attention backend '{ backend .value } ' is not usable because of missing package or the version is too old. Please install `flash-attn>={ _REQUIRED_FLASH_VERSION } `."
436415 )
437416
438- elif backend in [
439- AttentionBackendName ._FLASH_3 ,
440- AttentionBackendName ._FLASH_VARLEN_3 ,
441- ]:
417+ elif backend in [AttentionBackendName ._FLASH_3 , AttentionBackendName ._FLASH_VARLEN_3 ]:
442418 if not _CAN_USE_FLASH_ATTN_3 :
443419 raise RuntimeError (
444420 f"Flash Attention 3 backend '{ backend .value } ' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
@@ -510,11 +486,7 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
     cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
-    return (
-        (seqlens_q, seqlens_k),
-        (cu_seqlens_q, cu_seqlens_k),
-        (max_seqlen_q, max_seqlen_k),
-    )
+    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)


 def _prepare_for_flash_attn_or_sage_varlen_with_mask(
@@ -531,11 +503,7 @@ def _prepare_for_flash_attn_or_sage_varlen_with_mask(
     cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
-    return (
-        (seqlens_q, seqlens_k),
-        (cu_seqlens_q, cu_seqlens_k),
-        (max_seqlen_q, max_seqlen_k),
-    )
+    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)


 def _prepare_for_flash_attn_or_sage_varlen(
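A quick numeric sketch of what these varlen helpers build (hypothetical lengths, not values taken from the diff): the cumulative-sum boundaries plus the max length are the packed layout that flash-attn's varlen kernels expect.

# Illustrative sketch of the cu_seqlens construction shown above, with made-up lengths.
import torch

seqlens_k = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlens_k = torch.zeros(seqlens_k.numel() + 1, dtype=torch.int32)
cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)  # -> tensor([0, 3, 8, 10], dtype=torch.int32)
max_seqlen_k = seqlens_k.max().item()              # -> 5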
@@ -653,42 +621,22 @@ def _wrapped_flash_attn_3(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Hardcoded for now because pytorch does not support tuple/int type hints
     window_size = (-1, -1)
-    max_seqlen_q = q.shape[2]
-    max_seqlen_k = k.shape[2]
-
-    out, lse, *_ = flash_attn_3_forward(
+    out, lse, *_ = flash_attn_3_func(
         q=q,
         k=k,
         v=v,
-        k_new=None,
-        v_new=None,
+        softmax_scale=softmax_scale,
+        causal=causal,
         qv=qv,
-        out=None,
-        cu_seqlens_q=None,
-        cu_seqlens_k=None,
-        cu_seqlens_k_new=None,
-        seqused_q=None,
-        seqused_k=None,
-        max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_k,
-        page_table=None,
-        kv_batch_idx=None,
-        leftpad_k=None,
-        rotary_cos=None,
-        rotary_sin=None,
-        seqlens_rotary=None,
         q_descale=q_descale,
         k_descale=k_descale,
         v_descale=v_descale,
-        softmax_scale=softmax_scale,
-        causal=causal,
         window_size=window_size,
         attention_chunk=attention_chunk,
         softcap=softcap,
-        rotary_interleaved=True,
-        scheduler_metadata=None,
         num_splits=num_splits,
         pack_gqa=pack_gqa,
+        deterministic=deterministic,
         sm_margin=sm_margin,
     )
     lse = lse.permute(0, 2, 1)
@@ -794,10 +742,7 @@ def _native_attention_backward_op(

     grad_out_t = grad_out.permute(0, 2, 1, 3)
     grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
-        outputs=out,
-        inputs=[query_t, key_t, value_t],
-        grad_outputs=grad_out_t,
-        retain_graph=False,
+        outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
     )

     grad_query = grad_query_t.permute(0, 2, 1, 3)
@@ -836,26 +781,18 @@ def _cudnn_attention_forward_op(
     value = value.transpose(1, 2).contiguous()
     tensors_to_save += (query, key, value)

-    (
-        out,
-        lse,
-        cum_seq_q,
-        cum_seq_k,
-        max_q,
-        max_k,
-        philox_seed,
-        philox_offset,
-        debug_attn_mask,
-    ) = torch.ops.aten._scaled_dot_product_cudnn_attention(
-        query=query,
-        key=key,
-        value=value,
-        attn_bias=attn_mask,
-        compute_log_sumexp=return_lse,
-        dropout_p=dropout_p,
-        is_causal=is_causal,
-        return_debug_mask=False,
-        scale=scale,
+    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
+        torch.ops.aten._scaled_dot_product_cudnn_attention(
+            query=query,
+            key=key,
+            value=value,
+            attn_bias=attn_mask,
+            compute_log_sumexp=return_lse,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            return_debug_mask=False,
+            scale=scale,
+        )
     )

     tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
@@ -982,11 +919,7 @@ def _flash_attention_backward_op(
     **kwargs,
 ):
     query, key, value, out, lse, rng_state = ctx.saved_tensors
-    grad_query, grad_key, grad_value = (
-        torch.empty_like(query),
-        torch.empty_like(key),
-        torch.empty_like(value),
-    )
+    grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)

     lse_d = _wrapped_flash_attn_backward(  # noqa: F841
         grad_out,
@@ -1210,19 +1143,7 @@ def backward(

         grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))

-        return (
-            grad_query,
-            grad_key,
-            grad_value,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-        )
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None


 class TemplatedUlyssesAttention(torch.autograd.Function):
@@ -1317,19 +1238,7 @@ def backward(
             x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
         )

-        return (
-            grad_query,
-            grad_key,
-            grad_value,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-        )
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None


 def _templated_context_parallel_attention(
@@ -1677,12 +1586,7 @@ def _native_flex_attention(
         block_mask = attn_mask
     elif is_causal:
         block_mask = flex_attention.create_block_mask(
-            _flex_attention_causal_mask_mod,
-            batch_size,
-            num_heads,
-            seq_len_q,
-            seq_len_kv,
-            query.device,
+            _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device
         )
     elif torch.is_tensor(attn_mask):
         if attn_mask.ndim == 2:
@@ -1702,7 +1606,6 @@ def mask_mod(batch_idx, head_idx, q_idx, kv_idx):

             def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
                 return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
-
     else:
         raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")
