import inspect
import math
from enum import Enum
+ from functools import lru_cache
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
@@ ... @@
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS


- logger = get_logger(__name__)  # pylint: disable=invalid-name
-
_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
@@ ... @@
    flash_attn_3_func = None
    flash_attn_3_varlen_func = None

- if is_kernels_available():
-     from ..utils.kernels_utils import _get_fa3_from_hub
-
-     flash_attn_interface_hub = _get_fa3_from_hub()
-     if flash_attn_interface_hub is not None:
-         flash_attn_3_hub_func = flash_attn_interface_hub.flash_attn_func
-         flash_attn_3_varlen_hub_func = flash_attn_interface_hub.flash_attn_varlen_func
-     else:
-         flash_attn_3_hub_func = None
-         flash_attn_3_varlen_hub_func = None
- else:
-     flash_attn_3_hub_func = None
-     flash_attn_3_varlen_hub_func = None
-

if _CAN_USE_SAGE_ATTN:
    from sageattention import (
@@ -148,6 +133,7 @@ def wrap(func):
    _custom_op = custom_op_no_op
    _register_fake = register_fake_no_op

+ logger = get_logger(__name__)  # pylint: disable=invalid-name

# TODO(aryan): Add support for the following:
# - Sage Attention++
@@ -169,7 +155,7 @@ class AttentionBackendName(str, Enum):
    _FLASH_3 = "_flash_3"
    _FLASH_VARLEN_3 = "_flash_varlen_3"
    _FLASH_3_HUB = "_flash_3_hub"
-     _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.
+     # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.

    # PyTorch native
    FLEX = "flex"
@@ -224,6 +210,22 @@ def list_backends(cls):
        return list(cls._backends.keys())


+ @lru_cache(maxsize=None)
+ def _load_fa3_hub():
+     from ..utils.kernels_utils import _get_fa3_from_hub
+
+     fa3_hub = _get_fa3_from_hub()  # won't re-download if already present
+     if fa3_hub is None:
+         raise RuntimeError(
+             "Failed to load FlashAttention-3 kernels from the Hub. Please ensure the wheel is available for your platform."
+         )
+     return fa3_hub
+
+
+ def flash_attn_3_hub_func(*args, **kwargs):
+     return _load_fa3_hub().flash_attn_func(*args, **kwargs)
+
+
@contextlib.contextmanager
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
    """
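
Aside (not part of the patch): the hunk above swaps the import-time Hub fetch for a cached, on-demand loader. Below is a minimal, self-contained sketch of that pattern; make_fake_kernel is a hypothetical stand-in for _get_fa3_from_hub and is not a diffusers API, it exists only so the snippet runs on its own.

from functools import lru_cache
from types import SimpleNamespace

_calls = {"fetch": 0}


def make_fake_kernel():
    # Stand-in for the expensive Hub fetch (hypothetical helper).
    _calls["fetch"] += 1
    return SimpleNamespace(flash_attn_func=lambda *a, **k: "attn-output")


@lru_cache(maxsize=None)
def _load_kernel():
    # Runs at most once per process; later calls hit the lru_cache entry.
    kernel = make_fake_kernel()
    if kernel is None:
        raise RuntimeError("kernel unavailable for this platform")
    return kernel


def hub_attn(*args, **kwargs):
    # Thin module-level wrapper: importing the module stays cheap, and the
    # fetch cost is paid only on the first real call.
    return _load_kernel().flash_attn_func(*args, **kwargs)


hub_attn()
hub_attn()
assert _calls["fetch"] == 1  # fetched once despite two calls

This is also why the module-level flash_attn_3_hub_func assignment block is deleted in the first hunk: nothing network-bound runs at import time anymore.
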
@@ -374,12 +376,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
            raise RuntimeError(
                f"Flash Attention 3 Hub backend '{backend.value}
            )
-         if flash_attn_3_hub_func is None:
-             raise RuntimeError(
-                 "`flash_attn_3_hub_func` wasn't available. Please double if `kernels` was able to successfully pull the FA3 kernel from kernels-community/vllm-flash-attn3."
-             )
-     elif backend in [AttentionBackendName._FLASH_VARLEN_3_HUB]:
-         raise NotImplementedError

    elif backend in [
        AttentionBackendName.SAGE,
@@ -544,7 +540,7 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
    return torch.empty_like(query), query.new_empty(lse_shape)


- @_custom_op("vllm_flash_attn3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+ @_custom_op("vllm_flash_attn3::flash_attn", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_hub(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -553,7 +549,7 @@ def _wrapped_flash_attn_3_hub(
    return out, lse


- @_register_fake("vllm_flash_attn3::_flash_attn_forward")
+ @_register_fake("vllm_flash_attn3::flash_attn")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    batch_size, seq_len, num_heads, head_dim = query.shape
    lse_shape = (batch_size, seq_len, num_heads)
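
Aside (not part of the patch): the last two hunks change the op name consistently in both the _custom_op registration and its _register_fake counterpart, because torch.library resolves the fake (meta) kernel by that exact "namespace::name" string. A minimal sketch of the pairing, using a made-up mylib::toy_attn op rather than anything from diffusers:

import torch
from torch.library import custom_op, register_fake


@custom_op("mylib::toy_attn", mutates_args=(), device_types="cuda")
def toy_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Real CUDA implementation (plain SDPA-style math, for illustration only).
    return torch.softmax(q @ k.transpose(-2, -1), dim=-1) @ v


@register_fake("mylib::toy_attn")  # must match the string passed to custom_op
def _(q, k, v):
    # Shape-only "fake" kernel used for meta tensors / torch.compile tracing.
    return torch.empty_like(q)

If the two strings don't match, the fake kernel isn't attached to the op that actually runs, which is what the rename above keeps consistent.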