Skip to content

Commit a53a2e1

Browse files
committed
apply suggestions: split tests
1 parent 81ac311 commit a53a2e1

File tree

1 file changed

+41
-8
lines changed

1 file changed

+41
-8
lines changed

tests/tests_fabric/plugins/environments/test_xla.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,24 +102,57 @@ def test_detect(monkeypatch):
102102
@mock.patch.dict(os.environ, {}, clear=True)
103103
@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True)
104104
@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True)
105-
def test_attributes_from_xla_greater_21_used(xla_available, monkeypatch):
106-
"""Test XLA environment attributes when using XLA runtime >= 2.1."""
105+
def test_world_size_from_xla_runtime_greater_2_1(xla_available):
106+
"""Test that world_size uses torch_xla.runtime when XLA >= 2.1."""
107+
env = XLAEnvironment()
108+
109+
with mock.patch("torch_xla.runtime.world_size", return_value=4) as mock_world_size:
110+
env.world_size.cache_clear()
111+
assert env.world_size() == 4
112+
mock_world_size.assert_called_once()
107113

114+
115+
@mock.patch.dict(os.environ, {}, clear=True)
116+
@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True)
117+
@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True)
118+
def test_global_rank_from_xla_runtime_greater_2_1(xla_available):
119+
"""Test that global_rank uses torch_xla.runtime when XLA >= 2.1."""
120+
env = XLAEnvironment()
121+
122+
with mock.patch("torch_xla.runtime.global_ordinal", return_value=2) as mock_global_ordinal:
123+
env.global_rank.cache_clear()
124+
assert env.global_rank() == 2
125+
mock_global_ordinal.assert_called_once()
126+
127+
128+
@mock.patch.dict(os.environ, {}, clear=True)
129+
@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True)
130+
@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True)
131+
def test_local_rank_from_xla_runtime_greater_2_1(xla_available):
132+
"""Test that local_rank uses torch_xla.runtime when XLA >= 2.1."""
133+
env = XLAEnvironment()
134+
135+
with mock.patch("torch_xla.runtime.local_ordinal", return_value=1) as mock_local_ordinal:
136+
env.local_rank.cache_clear()
137+
assert env.local_rank() == 1
138+
mock_local_ordinal.assert_called_once()
139+
140+
141+
@mock.patch.dict(os.environ, {}, clear=True)
142+
@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True)
143+
@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True)
144+
def test_setters_readonly_when_xla_runtime_greater_2_1(xla_available):
145+
"""Test that set_world_size and set_global_rank don't affect values when using XLA runtime >= 2.1."""
108146
env = XLAEnvironment()
109147

110148
with (
111149
mock.patch("torch_xla.runtime.world_size", return_value=4),
112150
mock.patch("torch_xla.runtime.global_ordinal", return_value=2),
113-
mock.patch("torch_xla.runtime.local_ordinal", return_value=1),
114151
):
115152
env.world_size.cache_clear()
116153
env.global_rank.cache_clear()
117-
env.local_rank.cache_clear()
118-
119-
assert env.world_size() == 4
120-
assert env.global_rank() == 2
121-
assert env.local_rank() == 1
122154

155+
# Values should come from XLA runtime and not be affected by setters
123156
env.set_world_size(100)
124157
assert env.world_size() == 4
125158

0 commit comments

Comments
 (0)