Skip to content

Commit ee41f82

Browse files
author
AlexandrByzov
committed
feat: add tests for world_size, global_ordinal, local_ordinal
1 parent 4e8e86c commit ee41f82

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

tests/tests_fabric/plugins/environments/test_xla.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,34 @@ def test_detect(monkeypatch):
9797

9898
monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: True)
9999
assert XLAEnvironment.detect()
100+
101+
102+
@mock.patch.dict(os.environ, {}, clear=True)
103+
def test_attributes_from_xla_greater_21_used(xla_available, monkeypatch):
104+
"""Test XLA environment attributes when using XLA runtime >= 2.1."""
105+
monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_1", True)
106+
monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_GREATER_EQUAL_2_1", True)
107+
108+
env = XLAEnvironment()
109+
110+
with (
111+
mock.patch("torch_xla.runtime.world_size", return_value=4),
112+
mock.patch("torch_xla.runtime.global_ordinal", return_value=2),
113+
mock.patch("torch_xla.runtime.local_ordinal", return_value=1),
114+
):
115+
env.world_size.cache_clear()
116+
env.global_rank.cache_clear()
117+
env.local_rank.cache_clear()
118+
119+
assert env.world_size() == 4
120+
assert env.global_rank() == 2
121+
assert env.local_rank() == 1
122+
123+
env.set_world_size(100)
124+
assert env.world_size() == 4
125+
126+
env.set_global_rank(100)
127+
assert env.global_rank() == 2
128+
129+
env.set_local_rank(100)
130+
assert env.local_rank() == 1

0 commit comments

Comments
 (0)