@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,43 +15,49 @@
 import contextlib
 import functools
 import inspect
+import math
 from enum import Enum
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
 
 from ..utils import (
-    OptionalDependencyNotAvailable,
     get_logger,
+    is_flash_attn_3_available,
     is_flash_attn_available,
     is_flash_attn_version,
     is_sageattention_available,
     is_sageattention_version,
+    is_torch_npu_available,
     is_torch_version,
+    is_torch_xla_available,
+    is_torch_xla_version,
     is_xformers_available,
     is_xformers_version,
 )
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
 
 
-if is_flash_attn_available():
-    if is_flash_attn_version("<", "2.6.3"):
-        raise OptionalDependencyNotAvailable(
-            "The `flash-attn` library version is too old. Please update it to at least 2.6.3."
-        )
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
 
+if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"):
     from flash_attn import flash_attn_func, flash_attn_varlen_func
 else:
+    logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.")
     flash_attn_func = None
     flash_attn_varlen_func = None
 
 
-if is_sageattention_available():
-    if is_sageattention_version("<", "2.1.1"):
-        raise OptionalDependencyNotAvailable(
-            "The `sageattention` library version is too old. Please update it to at least 2.1.1."
-        )
+if is_flash_attn_3_available():
+    from flash_attn_interface import flash_attn_func as flash_attn_3_func
+    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+else:
+    flash_attn_3_func = None
+    flash_attn_3_varlen_func = None
 
+
+if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"):
     from sageattention import (
         sageattn,
         sageattn_qk_int8_pv_fp8_cuda,
@@ -61,6 +67,9 @@
         sageattn_varlen,
     )
 else:
+    logger.warning(
+        "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`."
+    )
     sageattn = None
     sageattn_qk_int8_pv_fp16_cuda = None
     sageattn_qk_int8_pv_fp16_triton = None
@@ -76,19 +85,25 @@
     import torch.nn.attention.flex_attention as flex_attention
 
 
-if is_xformers_available():
-    if is_xformers_version("<", "0.0.29"):
-        raise OptionalDependencyNotAvailable(
-            "The `xformers` library version is too old. Please update it to at least 0.0.29."
-        )
+if is_torch_npu_available():
+    from torch_npu import npu_fusion_attention
+else:
+    npu_fusion_attention = None
 
+
+if is_torch_xla_available() and is_torch_xla_version(">", "2.2"):
+    from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
+else:
+    xla_flash_attention = None
+
+
+if is_xformers_available() and is_xformers_version(">=", "0.0.29"):
     import xformers.ops as xops
 else:
+    logger.warning("`xformers` is not available or the version is too old. Please install `xformers>=0.0.29`.")
     xops = None
 
 
-logger = get_logger(__name__)  # pylint: disable=invalid-name
-
 _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
 _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
 _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
@@ -100,6 +115,8 @@ class AttentionBackendName(str, Enum):
     # `flash-attn`
     FLASH = "flash"
     FLASH_VARLEN = "flash_varlen"
+    _FLASH_3 = "_flash_3"
+    _FLASH_VARLEN_3 = "_flash_varlen_3"
 
     # PyTorch native
     FLEX = "flex"
@@ -108,6 +125,8 @@ class AttentionBackendName(str, Enum):
     _NATIVE_EFFICIENT = "_native_efficient"
     _NATIVE_FLASH = "_native_flash"
     _NATIVE_MATH = "_native_math"
+    _NATIVE_NPU = "_native_npu"
+    _NATIVE_XLA = "_native_xla"
 
     # `sageattention`
     SAGE = "sage"
@@ -274,7 +293,7 @@ def _check_shape(
 # ===== Helper functions =====
 
 
-@functools.lru_cache(maxsize=1)
+@functools.lru_cache(maxsize=8)
 def _prepare_for_flash_attn_or_sage_varlen(
     batch_size: int,
     seq_len_q: int,
@@ -371,12 +390,7 @@ def _flash_attention(
     alibi_slopes: Optional[torch.Tensor] = None,
     deterministic: bool = False,
     return_attn_probs: bool = False,
-    attn_mask: Optional[torch.Tensor] = None,
-    enable_gqa: bool = False,
 ) -> torch.Tensor:
-    if enable_gqa:
-        raise NotImplementedError("GQA is not yet supported.")
-
     query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
     out = flash_attn_func(
         q=query,
@@ -392,7 +406,6 @@ def _flash_attention(
         return_attn_probs=return_attn_probs,
     )
     out = out.permute(0, 2, 1, 3)
-
     return out
 
 
@@ -417,17 +430,13 @@ def _flash_varlen_attention(
     deterministic: bool = False,
     return_attn_probs: bool = False,
     attn_mask: Optional[torch.Tensor] = None,
-    enable_gqa: bool = False,
 ) -> torch.Tensor:
     batch_size, _, seq_len_q, _ = query.shape
     _, _, seq_len_kv, _ = key.shape
 
     if attn_mask is not None:
         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
 
-    if enable_gqa:
-        raise NotImplementedError("GQA is not yet supported.")
-
     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
         (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
             _prepare_for_flash_attn_or_sage_varlen(
@@ -473,6 +482,121 @@ def _flash_varlen_attention(
     return out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_3,
+    constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_3(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    window_size: Tuple[int, int] = (-1, -1),
+    softcap: float = 0.0,
+    deterministic: bool = False,
+    return_attn_probs: bool = False,
+) -> torch.Tensor:
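+    # flash-attn expects (batch, seq_len, num_heads, head_dim) inputs, so move the head dim after the sequence dim.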
+    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+    out, lse, *_ = flash_attn_3_func(
+        q=query,
+        k=key,
+        v=value,
+        softmax_scale=scale,
+        causal=is_causal,
+        qv=None,
+        q_descale=None,
+        k_descale=None,
+        v_descale=None,
+        window_size=window_size,
+        attention_chunk=0,
+        softcap=softcap,
+        num_splits=1,
+        pack_gqa=None,
+        deterministic=deterministic,
+        sm_margin=0,
+    )
+    out = out.permute(0, 2, 1, 3)
+    return (out, lse) if return_attn_probs else out
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_VARLEN_3,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_varlen_attention_3(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k: Optional[torch.Tensor] = None,
+    max_seqlen_q: Optional[int] = None,
+    max_seqlen_k: Optional[int] = None,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    window_size: Tuple[int, int] = (-1, -1),
+    softcap: float = 0.0,
+    deterministic: bool = False,
+    return_attn_probs: bool = False,
+    attn_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    batch_size, _, seq_len_q, _ = query.shape
+    _, _, seq_len_kv, _ = key.shape
+
+    if attn_mask is not None:
+        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen(
+                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+            )
+        )
+    else:
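+        # With explicit cu_seqlens, every batch entry is treated as having max_seqlen_k valid key/value tokens.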
+        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+
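+    # Drop padded key/value tokens from each batch entry before packing into the flat varlen layout.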
+    key_valid, value_valid = [], []
+    for b in range(batch_size):
+        valid_len = seqlens_k[b]
+        key_valid.append(key[b, :valid_len])
+        value_valid.append(value[b, :valid_len])
+
+    query_packed = query.flatten(0, 1)
+    key_packed = torch.cat(key_valid, dim=0)
+    value_packed = torch.cat(value_valid, dim=0)
+
+    out, lse, *_ = flash_attn_3_varlen_func(
+        q=query_packed,
+        k=key_packed,
+        v=value_packed,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        seqused_q=None,
+        seqused_k=None,
+        softmax_scale=scale,
+        causal=is_causal,
+        qv=None,
+        q_descale=None,
+        k_descale=None,
+        v_descale=None,
+        window_size=window_size,
+        softcap=softcap,
+        num_splits=1,
+        pack_gqa=None,
+        deterministic=deterministic,
+        sm_margin=0,
+    )
+    out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3)
+
+    return (out, lse) if return_attn_probs else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLEX,
     constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
@@ -668,6 +792,53 @@ def _native_math_attention(
         )
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._NATIVE_NPU,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _native_npu_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    scale: Optional[float] = None,
+) -> torch.Tensor:
+    return npu_fusion_attention(
+        query,
+        key,
+        value,
+        query.size(1),  # num_heads
+        input_layout="BNSD",
+        pse=None,
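+        # Default to the standard 1/sqrt(head_dim) softmax scale when none is provided.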
+        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+        pre_tockens=65536,
+        next_tockens=65536,
+        keep_prob=1.0 - dropout_p,
+        sync=False,
+        inner_precise=0,
+    )[0]
+
+
+# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._NATIVE_XLA,
+    constraints=[_check_device, _check_shape],
+)
+def _native_xla_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+) -> torch.Tensor:
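+    # Pre-scale the query by 1/sqrt(head_dim); no explicit softmax scale is passed to the XLA kernel.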
+    query = query / math.sqrt(query.shape[-1])
+    return xla_flash_attention(
+        q=query,
+        k=key,
+        v=value,
+        causal=is_causal,
+    )
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.SAGE,
     constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
@@ -707,17 +878,13 @@ def _sage_varlen_attention(
     scale: Optional[float] = None,
     smooth_k: bool = True,
     attn_mask: Optional[torch.Tensor] = None,
-    enable_gqa: bool = False,
 ) -> torch.Tensor:
     batch_size, _, seq_len_q, _ = query.shape
     _, _, seq_len_kv, _ = key.shape
 
     if attn_mask is not None:
         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
 
-    if enable_gqa:
-        raise NotImplementedError("GQA is not yet supported.")
-
     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
         (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
             _prepare_for_flash_attn_or_sage_varlen(