
Commit 1913913

Liyang90 and carmocca authored
Fix for hanging issue on TPU Pod (#16844)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 235e692 commit 1913913

6 files changed (+20, -13 lines)


src/lightning/fabric/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -64,6 +64,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed issue where the wrapped dataloader `iter()` would be called twice ([#16841](https://github.com/Lightning-AI/lightning/pull/16841))
 
+- Fixed DDP spawn hang on TPU Pods ([#16844](https://github.com/Lightning-AI/lightning/pull/16844))
+
 - Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))
 
 - Fixed parsing of defaults for `--accelerator` and `--precision` in Fabric CLI when `accelerator` and `precision` are set to non-default values in the code ([#16818](https://github.com/Lightning-AI/lightning/pull/16818))
```

src/lightning/fabric/strategies/launchers/xla.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -77,20 +77,21 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
 
     def _wrapping_function(
         self,
+        # XLA's multiprocessing returns the global index, not the local index as torch's multiprocessing
+        # https://github.com/pytorch/xla/blob/v1.13.0/torch_xla/distributed/xla_multiprocessing.py#L321
         process_idx: int,
         function: Callable,
         args: Any,
         kwargs: Any,
         return_queue: SimpleQueue,
         global_states: Optional[_GlobalStateSnapshot] = None,
     ) -> None:
-        self._strategy._local_rank = process_idx
         results = function(*args, **kwargs)
 
-        if process_idx == 0:
+        if self._strategy.local_rank == 0:
             return_queue.put(move_data_to_device(results, "cpu"))
 
-        _rank_teardown(process_idx)
+        _rank_teardown(self._strategy.local_rank)
 
 
 def _rank_teardown(rank: int) -> None:
```
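
For context on the comment added in the signature above: `xmp.spawn` invokes the wrapped function with the global ordinal, so treating `process_idx` as a per-host (local) rank misbehaves on a multi-host pod. A minimal sketch, assuming `torch_xla` is installed and the script runs on a TPU pod; the helper name `_show_ranks` is illustrative only:

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _show_ranks(process_idx: int) -> None:
    # On a single host the two ordinals coincide; on a multi-host pod
    # ``process_idx`` counts across all hosts (e.g. 0..31 on a v3-32),
    # while the local ordinal restarts at 0 on every host.
    print(f"process_idx={process_idx} global={xm.get_ordinal()} local={xm.get_local_ordinal()}")


if __name__ == "__main__":
    # ``xmp.spawn`` passes the (global) index as the first argument.
    xmp.spawn(_show_ranks)
```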

src/lightning/fabric/strategies/xla.py

Lines changed: 0 additions & 7 deletions
```diff
@@ -59,7 +59,6 @@ def __init__(
         self._checkpoint_io: Optional[CheckpointIO]
         self._backward_sync_control = None  # XLA synchronizes gradients in the optimizer.step() call
         self._launched = False
-        self._local_rank = 0
 
     @property
     def root_device(self) -> torch.device:
@@ -73,10 +72,6 @@ def root_device(self) -> torch.device:
     def num_processes(self) -> int:
         return len(self.parallel_devices) if self.parallel_devices is not None else 0
 
-    @property
-    def local_rank(self) -> int:
-        return self._local_rank
-
     @property
     def checkpoint_io(self) -> CheckpointIO:
         if self._checkpoint_io is None:
@@ -214,8 +209,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
     def _set_world_ranks(self) -> None:
         if self.cluster_environment is None:
             return
-        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-        self.cluster_environment.set_world_size(self.num_processes)
         rank_zero_only.rank = self.cluster_environment.global_rank()
 
     @staticmethod
```
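
With the strategy's own `_local_rank` bookkeeping and the explicit `set_global_rank`/`set_world_size` calls removed, rank information presumably comes straight from the cluster environment, which on TPU can query the XLA runtime itself (the remaining line still reads `self.cluster_environment.global_rank()`). A rough sketch of that idea; the class below is illustrative, not the library's `XLAEnvironment` source:

```python
import torch_xla.core.xla_model as xm


class SketchXLAEnvironment:
    """Illustrative stand-in for an XLA-aware cluster environment."""

    def world_size(self) -> int:
        return xm.xrt_world_size()

    def global_rank(self) -> int:
        return xm.get_ordinal()

    def local_rank(self) -> int:
        return xm.get_local_ordinal()
```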

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -370,6 +370,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue where `DistributedSampler.set_epoch` wasn't getting called during `trainer.predict` ([#16785](https://github.com/Lightning-AI/lightning/pull/16785), [#16826](https://github.com/Lightning-AI/lightning/pull/16826))
 
+- Fixed DDP spawn hang on TPU Pods ([#16844](https://github.com/Lightning-AI/lightning/pull/16844))
 
 - Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))
 
```

src/lightning/pytorch/strategies/launchers/xla.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -95,6 +95,8 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
 
     def _wrapping_function(
         self,
+        # XLA's multiprocessing returns the global index, not the local index as torch's multiprocessing
+        # https://github.com/pytorch/xla/blob/v1.13.0/torch_xla/distributed/xla_multiprocessing.py#L321
         process_idx: int,
         trainer: Optional["pl.Trainer"],
         function: Callable,
@@ -103,16 +105,15 @@
         return_queue: SimpleQueue,
         global_states: Optional[_GlobalStateSnapshot] = None,
     ) -> None:
-        self._strategy._local_rank = process_idx
         results = function(*args, **kwargs)
 
         if trainer is not None:
            results = self._collect_rank_zero_results(trainer, results)
 
-        if process_idx == 0:
+        if self._strategy.local_rank == 0:
             return_queue.put(move_data_to_device(results, "cpu"))
 
-        _rank_teardown(process_idx)
+        _rank_teardown(self._strategy.local_rank)
 
     def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
         rank_zero_debug("Collecting results from rank 0 process.")
```
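
This is the same guard change as in the Fabric launcher, with the trainer result collection in between. A simplified sketch of why the guard matters, based on a plausible reading of the hang rather than the launcher source: each host's parent process blocks on its own queue until one of its local workers puts a result, which never happens if the guard checks a global index on a non-zero host.

```python
import multiprocessing as mp


def _worker(global_idx: int, local_idx: int, queue) -> None:
    # Before the fix the guard was effectively ``global_idx == 0``: on any host
    # other than the first, no worker ever puts a result, so the parent's
    # ``queue.get()`` below would block forever.
    if local_idx == 0:  # the fixed guard: exactly one producer per host
        queue.put(f"result from global rank {global_idx}")


if __name__ == "__main__":
    queue = mp.SimpleQueue()
    # Pretend this parent is the second host of a pod with 8 local workers whose
    # global ordinals are 8..15 (numbers are illustrative).
    workers = [mp.Process(target=_worker, args=(8 + i, i, queue)) for i in range(8)]
    for w in workers:
        w.start()
    print(queue.get())  # returns only because the guard uses the local index
    for w in workers:
        w.join()
```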

src/lightning/pytorch/strategies/xla.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -93,6 +93,10 @@ def root_device(self) -> torch.device:
 
         return xm.xla_device()
 
+    @property
+    def local_rank(self) -> int:
+        return self.cluster_environment.local_rank() if self.cluster_environment is not None else 0
+
     @staticmethod
     def _validate_dataloader(dataloader: object) -> None:
         if not has_len(dataloader):
@@ -210,6 +214,11 @@ def setup_distributed(self) -> None:
         self.set_world_ranks()
         rank_zero_only.rank = self.global_rank
 
+    def set_world_ranks(self) -> None:
+        if self.cluster_environment is None:
+            return
+        rank_zero_only.rank = self.cluster_environment.global_rank()
+
     def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         assert self.model is not None
         with self.precision_plugin.val_step_context():
```
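
These additions give the PyTorch strategy a `local_rank` that delegates to the cluster environment and a `set_world_ranks` that only refreshes `rank_zero_only.rank` instead of forcing rank values. A hedged usage sketch of the kind of run this fix targets; the device count is illustrative and a real multi-host TPU environment is assumed:

```python
import lightning.pytorch as pl

# Illustrative only: a multi-host TPU run that previously hung, presumably
# because non-zero hosts never satisfied the old ``process_idx == 0`` guard
# and their launchers kept waiting on the return queue.
trainer = pl.Trainer(accelerator="tpu", devices=8)
# trainer.fit(model)  # `model` would be a LightningModule defined elsewhere
```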
