Commit 0afd4e1

Replace `using_pjrt()` with an xla runtime `device_type()` check in xla.py for torch-xla>=2.5 (#20442)
* Replace `using_pjrt()` with an xla runtime `device_type()` check in xla.py

  Fixes #20419. `torch_xla.runtime.using_pjrt()` is removed in pytorch/xla#7787. This PR replaces references to that function with a check of [`device_type()`](https://github.com/pytorch/xla/blob/master/torch_xla/runtime.py#L83) to recreate the behavior of that function, minus the manual initialization.

* Added/refactored tests for version compatibility

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

* precommit

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
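In short, on torch_xla >= 2.5 the removed boolean helper can be approximated by checking whether the runtime resolved a device type. A minimal sketch of that check, assuming torch_xla >= 2.5 is installed (`pjrt_in_use` is a hypothetical name used here for illustration; the real helper is `_using_pjrt` in the diff below):

```python
from torch_xla import runtime as xr


def pjrt_in_use() -> bool:
    # `xr.using_pjrt()` was removed in torch_xla 2.5 (pytorch/xla#7787).
    # `xr.device_type()` returns the configured PJRT device string
    # (e.g. "TPU" or "CPU") or None when no device is configured, so a
    # non-None result stands in for the old boolean check, minus the
    # manual initialization that `using_pjrt()` used to perform.
    return xr.device_type() is not None
```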
1 parent 1e88899 commit 0afd4e1

2 files changed (+12, −0 lines)

src/lightning/fabric/accelerators/xla.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -102,14 +102,21 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
 # PJRT support requires this minimum version
 _XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
 _XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
+_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
 
 
 def _using_pjrt() -> bool:
+    # `using_pjrt` is removed in torch_xla 2.5
+    if _XLA_GREATER_EQUAL_2_5:
+        from torch_xla import runtime as xr
+
+        return xr.device_type() is not None
     # delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped.
     if _XLA_GREATER_EQUAL_2_1:
         from torch_xla import runtime as xr
 
         return xr.using_pjrt()
+
     from torch_xla.experimental import pjrt
 
     return pjrt.using_pjrt()
```
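The version gate above relies on `RequirementCache` from `lightning-utilities`, which evaluates truthy when the installed package satisfies its requirement string. A minimal sketch of that gating pattern, assuming `lightning-utilities` is available and using the same pin as the diff above:

```python
from lightning_utilities.core.imports import RequirementCache

# Truthy only if the installed torch_xla satisfies the requirement string;
# falsy if torch_xla is missing or too old, so the guarded import is safe.
_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")

if _XLA_GREATER_EQUAL_2_5:
    # torch_xla >= 2.5: `using_pjrt()` is gone, so query the device type instead.
    from torch_xla import runtime as xr

    print(xr.device_type() is not None)
```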

tests/tests_fabric/accelerators/test_xla.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -44,3 +44,8 @@ def test_get_parallel_devices_raises(tpu_available):
         XLAAccelerator.get_parallel_devices(5)
     with pytest.raises(ValueError, match="Could not parse.*anything-else'"):
         XLAAccelerator.get_parallel_devices("anything-else")
+
+
+@pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present")
+def test_instantiate_xla_accelerator():
+    _ = XLAAccelerator()
```
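To run just the added test locally, a standard pytest selection should work; the `skipif` marker handles environments without torch_xla. A sketch, with the path taken from the repo layout above:

```python
import pytest

# Select only the new test by name; it is skipped when torch_xla is absent.
raise SystemExit(
    pytest.main(["tests/tests_fabric/accelerators/test_xla.py", "-k", "test_instantiate_xla_accelerator"])
)
```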

0 commit comments
