
Commit 41a5c2a

carmocca authored and lexierule committed

Fix XLAEnvironment detection on TPU pod (#16806)

tpu fixes

1 parent fc5bab6
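
In effect, on a TPU pod the connector was overwriting the XLAEnvironment that the TPU strategy selects for itself, so a wrong environment plugin was used whenever `accelerator=tpu` and `devices > 1`. A hedged sketch of the behavior this commit establishes (illustrative only: it needs a TPU host with torch_xla installed, and it mirrors the assertions added to the tests below):

    from pytorch_lightning import Trainer
    from lightning_fabric.plugins.environments import XLAEnvironment

    trainer = Trainer(accelerator="tpu", devices=8)
    # Previously the connector replaced the strategy's XLAEnvironment with its
    # generic default; now the strategy's own environment is kept.
    assert isinstance(trainer.strategy.cluster_environment, XLAEnvironment)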

File tree

10 files changed: +90 −23 lines


src/lightning_fabric/CHANGELOG.md

Lines changed: 20 additions & 0 deletions
@@ -5,6 +5,26 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


+## [1.9.3] - 2023-MM-DD
+
+### Added
+
+
+### Changed
+
+
+### Deprecated
+
+-
+
+
+### Removed
+
+### Fixed
+
+- Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))
+
+
 ## [1.9.2] - 2023-02-15

 - Fixed an attribute error and improved input validation for invalid strategy types being passed to Fabric ([#16693](https://github.com/Lightning-AI/lightning/pull/16693))

src/lightning_fabric/connector.py

Lines changed: 3 additions & 1 deletion
@@ -511,7 +511,9 @@ def _lazy_init_strategy(self) -> None:
         if self.checkpoint_io:
             self.strategy.checkpoint_io = self.checkpoint_io
         if hasattr(self.strategy, "cluster_environment"):
-            self.strategy.cluster_environment = self.cluster_environment
+            if self.strategy.cluster_environment is None:
+                self.strategy.cluster_environment = self.cluster_environment
+            self.cluster_environment = self.strategy.cluster_environment
         if hasattr(self.strategy, "parallel_devices"):
             if self.strategy.parallel_devices:
                 self._parallel_devices = self.strategy.parallel_devices
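
The overwritten assignment in this hunk was the root cause: it unconditionally replaced whatever cluster environment the strategy already carried (the XLA strategy installs an XLAEnvironment for itself). A minimal sketch of the guard pattern, using stand-in classes rather than Lightning's real ones:

    class XLAEnvironment: ...        # stand-in: what an XLA strategy picks itself
    class LightningEnvironment: ...  # stand-in: the connector's generic default

    class Strategy:
        def __init__(self, cluster_environment=None):
            self.cluster_environment = cluster_environment

    def lazy_init(strategy, connector_default):
        # Fall back to the connector's default only when the strategy has not
        # already chosen an environment of its own.
        if strategy.cluster_environment is None:
            strategy.cluster_environment = connector_default
        # Mirror the final choice back so connector and strategy agree.
        return strategy.cluster_environment

    xla = Strategy(cluster_environment=XLAEnvironment())
    assert isinstance(lazy_init(xla, LightningEnvironment()), XLAEnvironment)

    plain = Strategy()
    assert isinstance(lazy_init(plain, LightningEnvironment()), LightningEnvironment)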

src/pytorch_lightning/CHANGELOG.md

Lines changed: 8 additions & 0 deletions
@@ -5,12 +5,20 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


+## [1.9.3] - YYYY-MM-DD
+
+### Fixed
+
+- Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))
+
+
 ## [1.9.2] - 2023-02-15

 ### Changed

 - Disabled strict loading in multiprocessing launcher ("ddp_spawn", etc.) when loading weights back into the main process ([#16365](https://github.com/Lightning-AI/lightning/pull/16365))

+
 ### Fixed

 - Fixed an attribute error and improved input validation for invalid strategy types being passed to Trainer ([#16693](https://github.com/Lightning-AI/lightning/pull/16693))

src/pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 4 additions & 2 deletions
@@ -276,7 +276,7 @@ def _check_config_and_set_final_flags(
         if strategy is not None and strategy not in self._registered_strategies and not isinstance(strategy, Strategy):
             raise ValueError(
                 f"You selected an invalid strategy name: `strategy={strategy!r}`."
-                " It must be either a string or an instance of `lightning.pytorch.strategies.Strategy`."
+                " It must be either a string or an instance of `pytorch_lightning.strategies.Strategy`."
                 " Example choices: ddp, ddp_spawn, deepspeed, dp, ..."
                 " Find a complete list of options in our documentation at https://lightning.ai"
             )
@@ -821,7 +821,9 @@ def _lazy_init_strategy(self) -> None:
         if self.checkpoint_io:
             self.strategy.checkpoint_io = self.checkpoint_io
         if hasattr(self.strategy, "cluster_environment"):
-            self.strategy.cluster_environment = self.cluster_environment
+            if self.strategy.cluster_environment is None:
+                self.strategy.cluster_environment = self.cluster_environment
+            self.cluster_environment = self.strategy.cluster_environment
         if hasattr(self.strategy, "parallel_devices"):
             if self.strategy.parallel_devices:
                 self._parallel_devices = self.strategy.parallel_devices
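
The first hunk is a message-only fix: the 1.9 branch ships the `pytorch_lightning` namespace, so the error text should not point at `lightning.pytorch`. The guarded path is easy to exercise (assuming pytorch_lightning 1.9 is installed):

    import pytest
    from pytorch_lightning import Trainer

    with pytest.raises(ValueError, match="You selected an invalid strategy name"):
        Trainer(strategy="not_a_strategy")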

tests/tests_fabric/test_connector.py

Lines changed: 3 additions & 0 deletions
@@ -36,6 +36,7 @@
     LSFEnvironment,
     SLURMEnvironment,
     TorchElasticEnvironment,
+    XLAEnvironment,
 )
 from lightning_fabric.plugins.io import TorchCheckpointIO
 from lightning_fabric.strategies import (
@@ -69,6 +70,8 @@ def test_accelerator_choice_tpu(accelerator, devices):
         # accelerator=tpu, devices=None (default) maps to devices=auto (8) and then chooses XLAStrategy
         # This behavior may change in the future: https://github.com/Lightning-AI/lightning/issues/10606
         assert isinstance(connector.strategy, XLAStrategy)
+        assert isinstance(connector.strategy.cluster_environment, XLAEnvironment)
+        assert isinstance(connector.cluster_environment, XLAEnvironment)
     else:
         assert isinstance(connector.strategy, SingleTPUStrategy)

tests/tests_pytorch/accelerators/test_tpu.py

Lines changed: 12 additions & 6 deletions
@@ -81,7 +81,8 @@ def test_if_test_works_after_train(tmpdir):


 @RunIf(skip_windows=True)
-def test_accelerator_cpu_with_tpu_cores_flag(tpu_available):
+@mock.patch("pytorch_lightning.strategies.tpu_spawn.TPUSpawnStrategy.set_world_ranks")
+def test_accelerator_cpu_with_tpu_cores_flag(_, tpu_available):
     assert TPUAccelerator.is_available()

     trainer = Trainer(accelerator="cpu", devices=8)
@@ -94,7 +95,8 @@ def test_accelerator_cpu_with_tpu_cores_flag(tpu_available):

 @RunIf(skip_windows=True)
 @pytest.mark.parametrize(["accelerator", "devices"], [("auto", 8), ("auto", "auto"), ("tpu", None)])
-def test_accelerator_tpu(accelerator, devices, tpu_available):
+@mock.patch("pytorch_lightning.strategies.tpu_spawn.TPUSpawnStrategy.set_world_ranks")
+def test_accelerator_tpu(_, accelerator, devices, tpu_available):
     assert TPUAccelerator.is_available()

     trainer = Trainer(accelerator=accelerator, devices=devices)
@@ -104,7 +106,8 @@ def test_accelerator_tpu(accelerator, devices, tpu_available):


 @RunIf(skip_windows=True)
-def test_accelerator_tpu_with_tpu_cores_priority(tpu_available):
+@mock.patch("pytorch_lightning.strategies.tpu_spawn.TPUSpawnStrategy.set_world_ranks")
+def test_accelerator_tpu_with_tpu_cores_priority(_, tpu_available):
     """Test for checking `tpu_cores` flag takes priority over `devices`."""
     tpu_cores = 8
     with pytest.warns(UserWarning, match="The flag `devices=1` will be ignored,"):
@@ -115,7 +118,8 @@ def test_accelerator_tpu_with_tpu_cores_priority(tpu_available):


 @RunIf(skip_windows=True)
-def test_set_devices_if_none_tpu(tpu_available):
+@mock.patch("pytorch_lightning.strategies.tpu_spawn.TPUSpawnStrategy.set_world_ranks")
+def test_set_devices_if_none_tpu(_, tpu_available):
     with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
         trainer = Trainer(accelerator="tpu", tpu_cores=8)
     assert isinstance(trainer.accelerator, TPUAccelerator)
@@ -202,7 +206,8 @@ def test_strategy_choice_tpu_str_ddp_spawn(tpu_available):


 @RunIf(skip_windows=True)
-def test_strategy_choice_tpu_str_tpu_spawn_debug(tpu_available):
+@mock.patch("pytorch_lightning.strategies.tpu_spawn.TPUSpawnStrategy.set_world_ranks")
+def test_strategy_choice_tpu_str_tpu_spawn_debug(_, tpu_available):
     trainer = Trainer(strategy="tpu_spawn_debug", accelerator="tpu", devices=8)
     assert isinstance(trainer.strategy, TPUSpawnStrategy)

@@ -286,7 +291,8 @@ def test_tpu_invalid_raises_set_precision_with_strategy(tpu_available):


 @RunIf(skip_windows=True)
-def test_xla_checkpoint_plugin_being_default(tpu_available):
+@mock.patch("pytorch_lightning.strategies.tpu_spawn.TPUSpawnStrategy.set_world_ranks")
+def test_xla_checkpoint_plugin_being_default(_, tpu_available):
     trainer = Trainer(accelerator="tpu", devices=8)
     assert isinstance(trainer.strategy.checkpoint_io, XLACheckpointIO)

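A note on the recurring change in this file: these tests only simulate TPU availability (the `tpu_available` fixture), and with the strategy now keeping an XLAEnvironment, `set_world_ranks` would presumably reach for the real XLA runtime to resolve ranks, so it is stubbed out. A self-contained sketch of the same decorator pattern on a toy class:

    from unittest import mock

    class Strategy:
        def set_world_ranks(self):
            raise RuntimeError("would ask the XLA runtime for ranks")

    @mock.patch.object(Strategy, "set_world_ranks")
    def test_world_ranks_stubbed(patched):
        Strategy().set_world_ranks()  # intercepted; the real method never runs
        patched.assert_called_once()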

tests/tests_pytorch/deprecated_api/test_remove_2-0.py

Lines changed: 2 additions & 1 deletion
@@ -90,7 +90,8 @@ def test_v2_0_0_deprecated_gpus(cuda_count_4):


 @RunIf(skip_windows=True)
-def test_v2_0_0_deprecated_tpu_cores(tpu_available):
+@mock.patch("pytorch_lightning.strategies.tpu_spawn.TPUSpawnStrategy.set_world_ranks")
+def test_v2_0_0_deprecated_tpu_cores(_, tpu_available):
     with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
         _ = Trainer(tpu_cores=8)


tests/tests_pytorch/strategies/test_registry.py

Lines changed: 4 additions & 1 deletion
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
+
 import pytest

 from pytorch_lightning import Trainer
@@ -56,7 +58,8 @@ def test_deepspeed_strategy_registry_with_trainer(tmpdir, strategy):


 @RunIf(skip_windows=True)
-def test_tpu_spawn_debug_strategy_registry(xla_available):
+@mock.patch("pytorch_lightning.strategies.tpu_spawn.TPUSpawnStrategy.set_world_ranks")
+def test_tpu_spawn_debug_strategy_registry(_, xla_available):
     strategy = "tpu_spawn_debug"

     assert strategy in StrategyRegistry

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 24 additions & 1 deletion
@@ -28,12 +28,14 @@
     LSFEnvironment,
     SLURMEnvironment,
     TorchElasticEnvironment,
+    XLAEnvironment,
 )
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.cuda import CUDAAccelerator
 from pytorch_lightning.accelerators.mps import MPSAccelerator
+from pytorch_lightning.accelerators.tpu import TPUAccelerator
 from pytorch_lightning.plugins import DoublePrecisionPlugin, LayerSync, NativeSyncBatchNorm, PrecisionPlugin
 from pytorch_lightning.plugins.io import TorchCheckpointIO
 from pytorch_lightning.strategies import (
@@ -45,6 +47,8 @@
     DDPStrategy,
     DeepSpeedStrategy,
     SingleDeviceStrategy,
+    SingleTPUStrategy,
+    TPUSpawnStrategy,
 )
 from pytorch_lightning.strategies.ddp_spawn import _DDP_FORK_ALIASES
 from pytorch_lightning.strategies.hpu_parallel import HPUParallelStrategy
@@ -59,6 +63,24 @@ def test_accelerator_choice_cpu(tmpdir):
     assert isinstance(trainer.strategy, SingleDeviceStrategy)


+@RunIf(tpu=True, standalone=True)
+@pytest.mark.parametrize(
+    ["accelerator", "devices"], [("tpu", None), ("tpu", 1), ("tpu", [1]), ("tpu", 8), ("auto", 1), ("auto", 8)]
+)
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
+def test_accelerator_choice_tpu(accelerator, devices):
+    connector = AcceleratorConnector(accelerator=accelerator, devices=devices)
+    assert isinstance(connector.accelerator, TPUAccelerator)
+    if devices is None or (isinstance(devices, int) and devices > 1):
+        # accelerator=tpu, devices=None (default) maps to devices=auto (8) and then chooses TPUSpawnStrategy
+        # This behavior may change in the future: https://github.com/Lightning-AI/lightning/issues/10606
+        assert isinstance(connector.strategy, TPUSpawnStrategy)
+        assert isinstance(connector.strategy.cluster_environment, XLAEnvironment)
+        assert isinstance(connector.cluster_environment, XLAEnvironment)
+    else:
+        assert isinstance(connector.strategy, SingleTPUStrategy)
+
+
 def test_accelerator_invalid_choice():
     with pytest.raises(ValueError, match="You selected an invalid accelerator name: `accelerator='invalid'`"):
         Trainer(accelerator="invalid")
@@ -265,7 +287,8 @@ def test_interactive_compatible_dp_strategy_gpu(mps_count_0, cuda_count_2, monkeypatch):


 @RunIf(skip_windows=True)
-def test_interactive_compatible_strategy_tpu(tpu_available, monkeypatch):
+@mock.patch("pytorch_lightning.strategies.tpu_spawn.TPUSpawnStrategy.set_world_ranks")
+def test_interactive_compatible_strategy_tpu(_, tpu_available, monkeypatch):
     monkeypatch.setattr(pytorch_lightning.trainer.connectors.accelerator_connector, "_IS_INTERACTIVE", True)
     trainer = Trainer(accelerator="tpu")
     assert trainer.strategy.launcher.is_interactive_compatible
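
Unlike the mocked tests above, the new `test_accelerator_choice_tpu` runs standalone on real TPUs; `mock.patch.dict(os.environ, os.environ.copy(), clear=True)` hands it a throwaway copy of the environment, so variables set while the connector and XLA runtime initialize are rolled back on exit. A minimal demonstration of that rollback (the variable name is hypothetical):

    import os
    from unittest import mock

    @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
    def mutate_environment():
        os.environ["HYPOTHETICAL_TPU_VAR"] = "1"  # visible only inside the patch

    mutate_environment()
    assert "HYPOTHETICAL_TPU_VAR" not in os.environ  # restored afterwards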

tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py

Lines changed: 10 additions & 11 deletions
@@ -15,7 +15,6 @@
 import logging
 import os
 from unittest import mock
-from unittest.mock import PropertyMock

 import pytest
 import torch
@@ -152,19 +151,19 @@ def test_num_stepping_batches_with_tpu_single():
     assert trainer.estimated_stepping_batches == len(model.train_dataloader())


+class MultiprocessModel(BoringModel):
+    def on_train_start(self):
+        assert self.trainer.world_size == 8
+        assert self.trainer.estimated_stepping_batches == len(self.train_dataloader()) // 8
+
+
 @RunIf(tpu=True)
-@mock.patch(
-    "pytorch_lightning.strategies.tpu_spawn.TPUSpawnStrategy.root_device",
-    new_callable=PropertyMock,
-    return_value=torch.device("xla:0"),
-)
-def test_num_stepping_batches_with_tpu_multi(_):
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
+def test_num_stepping_batches_with_tpu_multi():
     """Test stepping batches with the TPU strategy across multiple devices."""
     trainer = Trainer(accelerator="tpu", devices=8, max_epochs=1)
-    model = BoringModel()
-    trainer._data_connector.attach_data(model)
-    trainer.strategy.connect(model)
-    assert trainer.estimated_stepping_batches == len(model.train_dataloader()) // 8
+    model = MultiprocessModel()
+    trainer.fit(model)


 @mock.patch("pytorch_lightning.accelerators.ipu.IPUAccelerator.is_available", return_value=True)
