2 files changed, +4 -4 lines changed
First file:

@@ -21,7 +21,7 @@

from ..image_processor import IPAdapterMaskProcessor
from ..utils import deprecate, is_torch_xla_available, logging
-from ..utils.import_utils import is_torch_npu_available, is_xformers_available, is_torch_xla_version
+from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available
from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph


@@ -39,8 +39,8 @@
if is_torch_xla_available():
    # flash attention pallas kernel is introduced in the torch_xla 2.3 release.
    if is_torch_xla_version(">", "2.2"):
-        from torch_xla.runtime import is_spmd
        from torch_xla.experimental.custom_kernel import flash_attention
+        from torch_xla.runtime import is_spmd
    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False
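For context, the block touched by the second hunk is a version-gated import guard: the pallas flash_attention kernel only ships with torch_xla releases newer than 2.2, so it and is_spmd are imported conditionally while XLA_AVAILABLE records whether torch_xla is present at all. The sketch below is a minimal standalone rendering of that pattern, not the library's own code; the plain __version__ parse stands in for diffusers' is_torch_xla_version helper, and pick_attention_backend is purely illustrative.

# Minimal sketch of the version-gated import guard shown above.
# flash_attention and is_spmd are real torch_xla symbols; the simplified
# version parse and pick_attention_backend() are illustrative assumptions.
XLA_AVAILABLE = False
flash_attention = None

try:
    import torch_xla

    # The pallas flash attention kernel ships with torch_xla releases newer than 2.2.
    if tuple(int(p) for p in torch_xla.__version__.split(".")[:2]) > (2, 2):
        from torch_xla.experimental.custom_kernel import flash_attention
        from torch_xla.runtime import is_spmd

    XLA_AVAILABLE = True
except ImportError:
    pass


def pick_attention_backend() -> str:
    """Use the pallas kernel when it was importable, otherwise fall back to SDPA."""
    return "xla_flash_attention" if flash_attention is not None else "sdpa"


if __name__ == "__main__":
    print(XLA_AVAILABLE, pick_attention_backend())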
Second file:

@@ -226,13 +226,13 @@ def fn_recursive_set_flash_attention(module: torch.nn.Module):
                fn_recursive_set_flash_attention(module)

    def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None):
-        r"""
+        r"""
        Enable the flash attention pallas kernel for torch_xla.
        """
        self.set_use_xla_flash_attention(True, partition_spec)

    def disable_xla_flash_attention(self):
-        r"""
+        r"""
        Disable the flash attention pallas kernel for torch_xla.
        """
        self.set_use_xla_flash_attention(False)
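The two methods touched in the second file are the public switches around set_use_xla_flash_attention. A hedged usage sketch follows: the model class and checkpoint are assumptions for illustration (any diffusers model exposing these methods would do); only the enable/disable calls and the optional partition_spec parameter come from the diff.

# Usage sketch: UNet2DConditionModel and the checkpoint below are illustrative
# assumptions; only enable_xla_flash_attention / disable_xla_flash_attention and
# the optional partition_spec argument appear in the diff itself.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.bfloat16
)

# Route attention through the torch_xla pallas kernel (needs torch_xla > 2.2 on an XLA device).
unet.enable_xla_flash_attention(partition_spec=None)

# ... run the denoising loop on the XLA device ...

# Restore the default attention processors afterwards.
unet.disable_xla_flash_attention()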