diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py
index 38d7380dc7905..9f620e6a4e3a4 100644
--- a/src/lightning/fabric/accelerators/xla.py
+++ b/src/lightning/fabric/accelerators/xla.py
@@ -102,14 +102,21 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
 # PJRT support requires this minimum version
 _XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
 _XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
+_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
 
 
 def _using_pjrt() -> bool:
+    # `using_pjrt` is removed in torch_xla 2.5
+    if _XLA_GREATER_EQUAL_2_5:
+        from torch_xla import runtime as xr
+
+        return xr.device_type() is not None
     # delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped.
     if _XLA_GREATER_EQUAL_2_1:
         from torch_xla import runtime as xr
 
         return xr.using_pjrt()
+
     from torch_xla.experimental import pjrt
 
     return pjrt.using_pjrt()
diff --git a/tests/tests_fabric/accelerators/test_xla.py b/tests/tests_fabric/accelerators/test_xla.py
index 1af7d7e1e7206..7a906c8ae0c54 100644
--- a/tests/tests_fabric/accelerators/test_xla.py
+++ b/tests/tests_fabric/accelerators/test_xla.py
@@ -44,3 +44,8 @@ def test_get_parallel_devices_raises(tpu_available):
     XLAAccelerator.get_parallel_devices(5)
     with pytest.raises(ValueError, match="Could not parse.*anything-else'"):
         XLAAccelerator.get_parallel_devices("anything-else")
+
+
+@pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present")
+def test_instantiate_xla_accelerator():
+    _ = XLAAccelerator()
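
Context for reviewers, not part of the diff: a minimal sketch of how the patched helper behaves once this lands. It assumes `torch_xla` is installed and only imports names visible in the file touched above (`_XLA_AVAILABLE`, `_using_pjrt`, `XLAAccelerator`); the script itself is hypothetical.

```python
# Hypothetical snippet (not in the PR) exercising the patched code path.
from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, XLAAccelerator, _using_pjrt

if _XLA_AVAILABLE:
    # On torch_xla >= 2.5 the helper no longer calls the removed `xr.using_pjrt()`;
    # it reports True when `xr.device_type()` returns a non-None device type.
    print("PJRT runtime usable:", _using_pjrt())
    # The new test expects plain instantiation to succeed whenever torch_xla is present.
    accelerator = XLAAccelerator()
else:
    print("torch_xla is not installed; nothing to exercise")
```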