@@ -284,7 +284,9 @@ def __init__(
             )
         self.set_processor(processor)

-    def set_use_xla_flash_attention(self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None) -> None:
+    def set_use_xla_flash_attention(
+        self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None
+    ) -> None:
         r"""
         Set whether to use xla flash attention from `torch_xla` or not.

@@ -296,7 +298,7 @@ def set_use_xla_flash_attention(self, use_xla_flash_attention: bool, partition_s
         """
         if use_xla_flash_attention:
             if not is_torch_xla_available:
-                raise "torch_xla is not available"
+                raise "torch_xla is not available"
             elif is_torch_xla_version("<", "2.3"):
                 raise "flash attention pallas kernel is supported from torch_xla version 2.3"
             elif is_spmd() and is_torch_xla_version("<", "2.4"):
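Note that the guard above (left untouched by this formatting-only change) raises bare string literals and tests `is_torch_xla_available` without calling it; on Python 3, raising a non-exception object fails with `TypeError: exceptions must derive from BaseException`. As a minimal sketch only, the same checks expressed with a proper exception class, assuming the `is_torch_xla_available`, `is_torch_xla_version` and `is_spmd` helpers already imported in this module (the helper name below is hypothetical, and the SPMD message is borrowed from `XLAFlashAttnProcessor2_0` further down):

# Sketch only, not part of this commit: the same availability/version checks
# raised as ImportError instead of bare strings.
def _check_xla_flash_attention_requirements():
    # is_torch_xla_available must be called; as a bare function object it is always truthy.
    if not is_torch_xla_available():
        raise ImportError("torch_xla is not available")
    if is_torch_xla_version("<", "2.3"):
        raise ImportError("flash attention pallas kernel is supported from torch_xla version 2.3")
    if is_spmd() and is_torch_xla_version("<", "2.4"):
        raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")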
@@ -2794,12 +2796,14 @@ class XLAFlashAttnProcessor2_0:

     def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            raise ImportError(
+                "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
         if is_torch_xla_version("<", "2.3"):
             raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
         if is_spmd() and is_torch_xla_version("<", "2.4"):
             raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
-        self.partition_spec = partition_spec
+        self.partition_spec = partition_spec

     def __call__(
         self,
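For context, a hypothetical usage sketch of the processor this hunk touches; the model variable, the partition spec values and the `set_attn_processor` call site are illustrative assumptions, not taken from this commit:

from diffusers.models.attention_processor import XLAFlashAttnProcessor2_0

# partition_spec is only forwarded to torch_xla's flash_attention when SPMD is active.
processor = XLAFlashAttnProcessor2_0(partition_spec=("data", None, None, None))
# `model` stands for an already-constructed diffusers model that exposes set_attn_processor (e.g. a UNet).
model.set_attn_processor(processor)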
@@ -2875,7 +2879,9 @@ def __call__(
             partition_spec = self.partition_spec if is_spmd() else None
             hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
         else:
-            logger.warning("Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096.")
+            logger.warning(
+                "Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096."
+            )
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
             )
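Putting the branch above into a self-contained form, a minimal sketch of the dispatch this hunk reformats, assuming torch_xla >= 2.3 with the pallas kernels available (the function name and argument handling are illustrative, not part of the commit):

import torch.nn.functional as F
from torch_xla.experimental.custom_kernel import flash_attention

def xla_attention(query, key, value, attention_mask=None, partition_spec=None):
    # The pallas flash attention kernel is only used when every Q/K/V tensor
    # has sequence length >= 4096; shorter sequences fall back to SDPA.
    if all(t.shape[2] >= 4096 for t in (query, key, value)):
        return flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )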