@@ -278,21 +278,33 @@ def __init__(
         # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
         # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
-        # If torch_xla is available with the correct version, we use pallas flash attention kernel to improve
-        # the performance.
         if processor is None:
-            if hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
-                if (
-                    is_torch_xla_available
-                    and is_torch_xla_version('>', '2.2')
-                    and (not is_spmd() or is_torch_xla_version('>', '2.3'))
-                ):
-                    processor = XLAFlashAttnProcessor2_0()
-                else:
-                    processor = AttnProcessor2_0()
-            else:
-                processor = AttnProcessor()
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
+        self.set_processor(processor)
+
+    def set_use_xla_flash_attention(self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None) -> None:
+        r"""
+        Set whether to use xla flash attention from `torch_xla` or not.

+        Args:
+            use_xla_flash_attention (`bool`):
+                Whether to use pallas flash attention kernel from `torch_xla` or not.
+            partition_spec (`Tuple[]`, *optional*):
+                Specify the partition specification if using SPMD. Otherwise None.
+        """
+        if (
+            use_xla_flash_attention
+            and is_torch_xla_available
+            and is_torch_xla_version('>', '2.2')
+            and (not is_spmd() or is_torch_xla_version('>', '2.3'))
+        ):
+            processor = XLAFlashAttnProcessor2_0(partition_spec)
+        else:
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
         self.set_processor(processor)

     def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
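A minimal usage sketch of the new setter (not part of this diff), assuming a diffusers model such as `UNet2DConditionModel` whose `Attention` modules expose `set_use_xla_flash_attention`; the checkpoint id and the partition spec are placeholder values:

from diffusers import UNet2DConditionModel

# Hypothetical checkpoint; any model whose Attention modules carry the new setter works the same way.
unet = UNet2DConditionModel.from_pretrained("some/checkpoint", subfolder="unet")
for module in unet.modules():
    if hasattr(module, "set_use_xla_flash_attention"):
        # Example SPMD partition spec; pass None when not sharding.
        module.set_use_xla_flash_attention(True, partition_spec=("data", None, None, None))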
@@ -2772,16 +2784,17 @@ def __call__(

 class XLAFlashAttnProcessor2_0:
     r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using torch_xla).
+    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
     """

-    def __init__(self):
+    def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
         if is_torch_xla_version("<", "2.3"):
             raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
         if is_spmd() and is_torch_xla_version("<", "2.4"):
             raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+        self.partition_spec = partition_spec

     def __call__(
         self,
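Alternatively, a sketch (also not from this diff) of installing the processor explicitly rather than through `set_use_xla_flash_attention`, assuming `unet` is a model exposing the standard `set_attn_processor` helper:

from diffusers.models.attention_processor import XLAFlashAttnProcessor2_0

# The stored partition_spec is only consulted when SPMD is active; otherwise None is passed to the kernel.
processor = XLAFlashAttnProcessor2_0(partition_spec=("data", None, None, None))
unet.set_attn_processor(processor)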
@@ -2854,7 +2867,7 @@ def __call__(
                 # Apply attention mask to key
                 key = key + attention_mask
             query /= math.sqrt(query.shape[3])
-            partition_spec = ("data", None, None, None) if is_spmd() else None
+            partition_spec = self.partition_spec if is_spmd() else None
             hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
         else:
             hidden_states = F.scaled_dot_product_attention(
@@ -5201,6 +5214,7 @@ def __init__(self):
     FusedCogVideoXAttnProcessor2_0,
     XFormersAttnAddedKVProcessor,
     XFormersAttnProcessor,
+    XLAFlashAttnProcessor2_0,
     AttnProcessorNPU,
     AttnProcessor2_0,
     MochiVaeAttnProcessor2_0,