 
 from ..image_processor import IPAdapterMaskProcessor
 from ..utils import deprecate, is_torch_xla_available, logging
-from ..utils.import_utils import is_torch_npu_available, is_xformers_available
+from ..utils.import_utils import is_torch_npu_available, is_xformers_available, is_torch_xla_version
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
 
@@ -37,8 +37,10 @@
     xformers = None
 
 if is_torch_xla_available():
-    from torch_xla.experimental.custom_kernel import flash_attention
-
+    # flash attention pallas kernel is introduced in the torch_xla 2.3 release.
+    if is_torch_xla_version(">", "2.2"):
+        from torch_xla.runtime import is_spmd
+        from torch_xla.experimental.custom_kernel import flash_attention
     XLA_AVAILABLE = True
 else:
     XLA_AVAILABLE = False
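Note: the guard above keeps the module importable when torch_xla predates the pallas kernel. A minimal, hedged sketch of how that gate resolves on a given install; the helper name is illustrative, and the public import paths are assumed to match the relative imports in this diff.

from diffusers.utils import is_torch_xla_available
from diffusers.utils.import_utils import is_torch_xla_version


def xla_flash_attention_importable() -> bool:
    # Mirrors the guard above: torch_xla must be installed and newer than 2.2
    # before flash_attention / is_spmd can be imported from torch_xla.
    return is_torch_xla_available() and is_torch_xla_version(">", "2.2")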
@@ -276,16 +278,21 @@ def __init__(
         # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
         # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
-        # If torch_xla is available, we use pallas flash attention kernel to improve the performance.
+        # If torch_xla is available with the correct version, we use pallas flash attention kernel to improve
+        # the performance.
         if processor is None:
             if hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
-                if is_torch_xla_available:
+                if (
+                    is_torch_xla_available
+                    and is_torch_xla_version('>', '2.2')
+                    and (not is_spmd() or is_torch_xla_version('>', '2.3'))
+                ):
                     processor = XLAFlashAttnProcessor2_0()
                 else:
                     processor = AttnProcessor2_0()
             else:
                 processor = AttnProcessor()
-
+
         self.set_processor(processor)
 
     def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
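Note: condensed into standalone form, the selection in Attention.__init__ now reads roughly as the hedged sketch below. `pick_attention_processor` is an illustrative helper, not diffusers API; it assumes a diffusers build that contains this change (so the processors live in diffusers.models.attention_processor and is_torch_xla_version is available) and an environment where torch_xla is installed so that is_spmd can be imported.

import torch.nn.functional as F

from diffusers.models.attention_processor import (
    AttnProcessor,
    AttnProcessor2_0,
    XLAFlashAttnProcessor2_0,
)
from diffusers.utils import is_torch_xla_available
from diffusers.utils.import_utils import is_torch_xla_version
from torch_xla.runtime import is_spmd


def pick_attention_processor(scale_qk: bool = True):
    # Illustrative re-statement of the default-processor choice above.
    if hasattr(F, "scaled_dot_product_attention") and scale_qk:
        # The pallas flash attention path needs torch_xla newer than 2.2, and
        # newer than 2.3 when SPMD is enabled.
        if (
            is_torch_xla_available()
            and is_torch_xla_version(">", "2.2")
            and (not is_spmd() or is_torch_xla_version(">", "2.3"))
        ):
            return XLAFlashAttnProcessor2_0()
        return AttnProcessor2_0()
    return AttnProcessor()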
@@ -2771,8 +2778,10 @@ class XLAFlashAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-        if not is_torch_xla_available:
-            raise ImportError("XLAFlashAttnProcessor2_0 required torch_xla package.")
+        if is_torch_xla_version("<", "2.3"):
+            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+        if is_spmd() and is_torch_xla_version("<", "2.4"):
+            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
 
     def __call__(
         self,
@@ -2784,10 +2793,6 @@ def __call__(
         *args,
         **kwargs,
     ) -> torch.Tensor:
-        if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
-            deprecate("scale", "1.0.0", deprecation_message)
-
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -2836,7 +2841,7 @@ def __call__(
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
-        if XLA_AVAILABLE and all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
+        if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
             if attention_mask is not None:
                 attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
                 # Convert mask to float and replace 0s with -inf and 1s with 0
@@ -2849,7 +2854,8 @@ def __call__(
                 # Apply attention mask to key
                 key = key + attention_mask
             query /= math.sqrt(query.shape[3])
-            hidden_states = flash_attention(query, key, value, causal=False, partition_spec=("data", None, None, None))
+            partition_spec = ("data", None, None, None) if is_spmd() else None
+            hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
         else:
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
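Note: inside XLAFlashAttnProcessor2_0.__call__, the change above boils down to the dispatch sketched below. This is a hedged, simplified sketch: `xla_flash_or_sdpa` is an illustrative helper, tensors are laid out as (batch, heads, seq_len, head_dim), and attention_mask, if passed, is assumed to already be an additive float mask of shape (batch, 1, 1, key_len).

import math

import torch.nn.functional as F
from torch_xla.experimental.custom_kernel import flash_attention
from torch_xla.runtime import is_spmd


def xla_flash_or_sdpa(query, key, value, attention_mask=None):
    # The pallas kernel is only used for long sequences; shorter ones fall back to SDPA.
    if all(t.shape[2] >= 4096 for t in (query, key, value)):
        if attention_mask is not None:
            # As in the hunk above, the additive mask is folded into the key.
            key = key + attention_mask.to(key.dtype)
        # Pre-scale the query by 1/sqrt(head_dim), mirroring the code above.
        query = query / math.sqrt(query.shape[3])
        # A batch-sharded partition spec is only passed when running under SPMD.
        partition_spec = ("data", None, None, None) if is_spmd() else None
        return flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )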