
Commit a0ce930

Authored by AlexandrByzov, bhimrazy, and Borda
bugfix: add support for global_ordinal, local_ordinal, world_size in xla (#20872)
* bugfix: add support for global_ordinal, local_ordinal, world_size in xla
* fix: remove set local rank
* Apply suggestions from code review

---------

Co-authored-by: AlexandrByzov <[email protected]>
Co-authored-by: Bhimraj Yadav <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 2f5c4b6 commit a0ce930
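
The change swaps the deprecated `torch_xla.core.xla_model` helpers for their `torch_xla.runtime` counterparts on torch_xla >= 2.1: `xm.xrt_world_size()` becomes `xr.world_size()`, `xm.get_ordinal()` becomes `xr.global_ordinal()`, and `xm.get_local_ordinal()` becomes `xr.local_ordinal()`. A minimal standalone sketch of the version-gated pattern the diff applies (the `xla_world_size` helper is hypothetical and requires torch_xla; the real change lives in `XLAEnvironment`, shown in the diffs below):

```python
# Hypothetical standalone helper mirroring the commit's pattern
# (requires torch_xla; not runnable on a CPU-only host).
from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1


def xla_world_size() -> int:
    if _XLA_GREATER_EQUAL_2_1:
        # torch_xla >= 2.1: use the runtime API
        from torch_xla import runtime as xr

        return xr.world_size()
    # older torch_xla: fall back to the deprecated xla_model helper
    import torch_xla.core.xla_model as xm

    return xm.xrt_world_size()
```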

File tree: 3 files changed (+82 / -0 lines)


src/lightning/fabric/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
```diff
@@ -16,11 +16,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 -
 
+
 ### Changed
 
 - Raise ValueError when seed is `out-of-bounds` or `cannot be cast to int` ([#21029](https://github.com/Lightning-AI/pytorch-lightning/pull/21029))
 
 
+### Fixed
+
+- Fix XLA strategy to add support for `global_ordinal`, `local_ordinal`, `world_size` which came instead of deprecated methods ([#20852](https://github.com/Lightning-AI/pytorch-lightning/issues/20852))
+
+
 - fix: remove extra `name` parameter in accelerator registry decorator ([#20975](https://github.com/Lightning-AI/pytorch-lightning/pull/20975))
 
 
```

src/lightning/fabric/plugins/environments/xla.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -66,6 +66,11 @@ def world_size(self) -> int:
         The output is cached for performance.
 
         """
+        if _XLA_GREATER_EQUAL_2_1:
+            from torch_xla import runtime as xr
+
+            return xr.world_size()
+
         import torch_xla.core.xla_model as xm
 
         return xm.xrt_world_size()
@@ -82,6 +87,11 @@ def global_rank(self) -> int:
         The output is cached for performance.
 
         """
+        if _XLA_GREATER_EQUAL_2_1:
+            from torch_xla import runtime as xr
+
+            return xr.global_ordinal()
+
         import torch_xla.core.xla_model as xm
 
         return xm.get_ordinal()
@@ -98,6 +108,11 @@ def local_rank(self) -> int:
         The output is cached for performance.
 
         """
+        if _XLA_GREATER_EQUAL_2_1:
+            from torch_xla import runtime as xr
+
+            return xr.local_ordinal()
+
         import torch_xla.core.xla_model as xm
 
         return xm.get_local_ordinal()
```
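
For orientation, a hedged usage sketch of the patched environment (assuming the public `lightning.fabric.plugins.environments` import path and a host with torch_xla and an XLA device; the values in the comments are illustrative):

```python
from lightning.fabric.plugins.environments import XLAEnvironment

env = XLAEnvironment()
# On torch_xla >= 2.1 these now resolve through torch_xla.runtime:
print(env.world_size())   # via xr.world_size(), e.g. 8 on a v3-8 TPU
print(env.global_rank())  # via xr.global_ordinal()
print(env.local_rank())   # via xr.local_ordinal()
```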

tests/tests_fabric/plugins/environments/test_xla.py

Lines changed: 61 additions & 0 deletions
```diff
@@ -97,3 +97,64 @@ def test_detect(monkeypatch):
 
     monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: True)
     assert XLAEnvironment.detect()
+
+
+@mock.patch.dict(os.environ, {}, clear=True)
+@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True)
+@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True)
+def test_world_size_from_xla_runtime_greater_2_1(xla_available):
+    """Test that world_size uses torch_xla.runtime when XLA >= 2.1."""
+    env = XLAEnvironment()
+
+    with mock.patch("torch_xla.runtime.world_size", return_value=4) as mock_world_size:
+        env.world_size.cache_clear()
+        assert env.world_size() == 4
+        mock_world_size.assert_called_once()
+
+
+@mock.patch.dict(os.environ, {}, clear=True)
+@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True)
+@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True)
+def test_global_rank_from_xla_runtime_greater_2_1(xla_available):
+    """Test that global_rank uses torch_xla.runtime when XLA >= 2.1."""
+    env = XLAEnvironment()
+
+    with mock.patch("torch_xla.runtime.global_ordinal", return_value=2) as mock_global_ordinal:
+        env.global_rank.cache_clear()
+        assert env.global_rank() == 2
+        mock_global_ordinal.assert_called_once()
+
+
+@mock.patch.dict(os.environ, {}, clear=True)
+@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True)
+@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True)
+def test_local_rank_from_xla_runtime_greater_2_1(xla_available):
+    """Test that local_rank uses torch_xla.runtime when XLA >= 2.1."""
+    env = XLAEnvironment()
+
+    with mock.patch("torch_xla.runtime.local_ordinal", return_value=1) as mock_local_ordinal:
+        env.local_rank.cache_clear()
+        assert env.local_rank() == 1
+        mock_local_ordinal.assert_called_once()
+
+
+@mock.patch.dict(os.environ, {}, clear=True)
+@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True)
+@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True)
+def test_setters_readonly_when_xla_runtime_greater_2_1(xla_available):
+    """Test that set_world_size and set_global_rank don't affect values when using XLA runtime >= 2.1."""
+    env = XLAEnvironment()
+
+    with (
+        mock.patch("torch_xla.runtime.world_size", return_value=4),
+        mock.patch("torch_xla.runtime.global_ordinal", return_value=2),
+    ):
+        env.world_size.cache_clear()
+        env.global_rank.cache_clear()
+
+        # Values should come from XLA runtime and not be affected by setters
+        env.set_world_size(100)
+        assert env.world_size() == 4
+
+        env.set_global_rank(100)
+        assert env.global_rank() == 2
```
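
The `cache_clear()` calls in these tests matter because the environment methods cache their results (their docstrings read "The output is cached for performance."), so a stale cached value would bypass the mocked runtime functions. A minimal illustration of that caching behavior, assuming `functools.lru_cache`-style caching on methods (the `Env` class is hypothetical, not Lightning's implementation; requires Python 3.8+ for `lru_cache` on methods):

```python
import functools


class Env:
    """Hypothetical stand-in for a cached environment method."""

    @functools.lru_cache(maxsize=1)
    def world_size(self) -> int:
        print("querying runtime...")
        return 4


env = Env()
env.world_size()              # prints "querying runtime...", returns 4
env.world_size()              # served from the cache: no print
env.world_size.cache_clear()  # reset the cache, as the tests do
env.world_size()              # recomputes: prints again
```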
