@@ -294,13 +294,15 @@ def set_use_xla_flash_attention(self, use_xla_flash_attention: bool, partition_s
             partition_spec (`Tuple[]`, *optional*):
                 Specify the partition specification if using SPMD. Otherwise None.
         """
-        if (
-            use_xla_flash_attention
-            and is_torch_xla_available
-            and is_torch_xla_version('>', '2.2')
-            and (not is_spmd() or is_torch_xla_version('>', '2.3'))
-        ):
-            processor = XLAFlashAttnProcessor2_0(partition_spec)
+        if use_xla_flash_attention:
+            if not is_torch_xla_available():
+                raise ImportError("torch_xla is not available")
+            elif is_torch_xla_version("<", "2.3"):
+                raise ImportError("flash attention pallas kernel is supported from torch_xla version 2.3")
+            elif is_spmd() and is_torch_xla_version("<", "2.4"):
+                raise ImportError("flash attention pallas kernel using SPMD is supported from torch_xla version 2.4")
+            else:
+                processor = XLAFlashAttnProcessor2_0(partition_spec)
         else:
             processor = (
                 AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
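
With this change, enabling the XLA flash attention path on an `Attention` module fails loudly when `torch_xla` is missing or too old, instead of silently keeping the default processor. A minimal sketch of what a caller sees, assuming `torch_xla` is not installed (the constructor arguments are illustrative only, and the guards are assumed to raise `ImportError` as shown above):

```python
from diffusers.models.attention_processor import Attention

# Illustrative module; query_dim is the only required constructor argument.
attn = Attention(query_dim=64)

try:
    # Requests XLAFlashAttnProcessor2_0; with the guards above this now raises
    # when torch_xla is absent or older than 2.3 (2.4 when running under SPMD).
    attn.set_use_xla_flash_attention(True)
except ImportError as err:
    print(f"XLA flash attention unavailable: {err}")
```
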
@@ -2871,6 +2873,7 @@ def __call__(
             partition_spec = self.partition_spec if is_spmd() else None
             hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
         else:
+            logger.warning("Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096.")
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
             )
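
For context, the fallback above is only taken when the pallas kernel cannot be used; the warning suggests the gate is on the query/key/value sequence length. A minimal sketch of that dispatch decision, assuming the tensors are laid out as `(batch, heads, seq_len, head_dim)` and the kernel requires a sequence length of at least 4096 (the threshold is inferred from the warning message, not shown in this hunk):

```python
import torch

# Hypothetical helper mirroring the dispatch implied by the warning above.
# Assumption: the sequence axis is dim 2 and the pallas kernel needs seq_len >= 4096.
def can_use_pallas_flash_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> bool:
    return all(t.shape[2] >= 4096 for t in (query, key, value))

q = torch.randn(1, 8, 4096, 64)
k = torch.randn(1, 8, 2048, 64)  # too short, so the SDPA fallback (and the warning) is used
v = torch.randn(1, 8, 2048, 64)
print(can_use_pallas_flash_attention(q, k, v))  # False
```
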