
Commit 5c3de46

Liyang90 authored and carmocca committed

Fix for hanging issue on TPU Pod (#16844)

Co-authored-by: Carlos Mocholí <[email protected]>

1 parent f80f2f9

File tree: 6 files changed, +23 −15 lines

src/lightning_fabric/CHANGELOG.md

Lines changed: 3 additions & 1 deletion

@@ -7,7 +7,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [1.9.4] - 2023-02-28
 
-No changes.
+### Fixed
+
+- Fixed DDP spawn hang on TPU Pods ([#16844](https://github.com/Lightning-AI/lightning/pull/16844))
 
 
 ## [1.9.3] - 2023-02-21

src/lightning_fabric/strategies/launchers/xla.py

Lines changed: 4 additions & 3 deletions

@@ -77,20 +77,21 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
 
     def _wrapping_function(
         self,
+        # XLA's multiprocessing returns the global index, not the local index as torch's multiprocessing
+        # https://github.com/pytorch/xla/blob/v1.13.0/torch_xla/distributed/xla_multiprocessing.py#L321
        process_idx: int,
         function: Callable,
         args: Any,
         kwargs: Any,
         return_queue: SimpleQueue,
         global_states: Optional[_GlobalStateSnapshot] = None,
     ) -> None:
-        self._strategy._local_rank = process_idx
         results = function(*args, **kwargs)
 
-        if process_idx == 0:
+        if self._strategy.local_rank == 0:
             return_queue.put(move_data_to_device(results, "cpu"))
 
-        _rank_teardown(process_idx)
+        _rank_teardown(self._strategy.local_rank)
 
 
 def _rank_teardown(rank: int) -> None:
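
The two added comment lines carry the heart of the fix: XLA's multiprocessing hands each worker its pod-global ordinal, so on a multi-host TPU Pod `process_idx == 0` is true for exactly one worker in the entire pod, while every host's parent process waits on its own `return_queue`. Below is a minimal, self-contained sketch of that arithmetic (plain Python, no `torch_xla`; the 2-host, 8-cores-per-host pod shape is an illustrative assumption). The same reasoning applies to the `pytorch_lightning` launcher further down.

```python
# Sketch: why gating on the global index hangs a TPU Pod while gating
# on the local rank does not. The pod shape is a hypothetical example.
NUM_HOSTS = 2
PROCS_PER_HOST = 8

for host in range(NUM_HOSTS):
    # Each host's parent process blocks on return_queue.get(), so it
    # needs exactly one of *its own* workers to put a result.
    old_check_fires = False  # old gate: process_idx == 0
    new_check_fires = False  # fixed gate: local_rank == 0
    for local_rank in range(PROCS_PER_HOST):
        # XLA's multiprocessing hands every worker its global ordinal.
        process_idx = host * PROCS_PER_HOST + local_rank
        old_check_fires |= process_idx == 0
        new_check_fires |= local_rank == 0
    print(f"host {host}: old check fires={old_check_fires}, "
          f"new check fires={new_check_fires}")

# Output:
# host 0: old check fires=True, new check fires=True
# host 1: old check fires=False, new check fires=True
#   -> with the old check, host 1's parent waits on return_queue forever.
```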

src/lightning_fabric/strategies/xla.py

Lines changed: 0 additions & 7 deletions

@@ -59,7 +59,6 @@ def __init__(
         self._checkpoint_io: Optional[CheckpointIO]
         self._backward_sync_control = None  # XLA synchronizes gradients in the optimizer.step() call
         self._launched = False
-        self._local_rank = 0
 
     @property
     def root_device(self) -> torch.device:
@@ -73,10 +72,6 @@ def root_device(self) -> torch.device:
     def num_processes(self) -> int:
         return len(self.parallel_devices) if self.parallel_devices is not None else 0
 
-    @property
-    def local_rank(self) -> int:
-        return self._local_rank
-
     @property
     def checkpoint_io(self) -> CheckpointIO:
         if self._checkpoint_io is None:
@@ -214,8 +209,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
     def _set_world_ranks(self) -> None:
         if self.cluster_environment is None:
             return
-        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-        self.cluster_environment.set_world_size(self.num_processes)
         rank_zero_only.rank = self.cluster_environment.global_rank()
 
     @staticmethod
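
With both the hard-coded `_local_rank` attribute and the property override deleted, `XLAStrategy` falls back to the `local_rank` it inherits from `ParallelStrategy`, which asks the cluster environment; on XLA that environment reads ranks straight from the runtime. A minimal sketch of that delegation chain, with names mirroring Lightning's but bodies that are illustrative assumptions based on `torch_xla`'s ordinal helpers, not library source:

```python
# Sketch of the delegation the deletions rely on. XLAEnvironmentSketch
# stands in for Lightning's XLAEnvironment; the bodies are assumptions.
import torch_xla.core.xla_model as xm

class XLAEnvironmentSketch:
    """Rank information comes from the XLA runtime itself."""

    def global_rank(self) -> int:
        return xm.get_ordinal()        # pod-wide index

    def local_rank(self) -> int:
        return xm.get_local_ordinal()  # per-host index

class ParallelStrategySketch:
    def __init__(self, cluster_environment=None):
        self.cluster_environment = cluster_environment

    @property
    def local_rank(self) -> int:
        env = self.cluster_environment
        return env.local_rank() if env is not None else 0
```

This is also why `_set_world_ranks` can drop its `set_global_rank`/`set_world_size` calls: the environment now derives both values from the runtime inside each spawned worker, so nothing needs to be pushed into it before spawning.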

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 1 deletion

@@ -7,7 +7,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [1.9.4] - 2023-02-28
 
-No changes.
+### Fixed
+
+- Fixed DDP spawn hang on TPU Pods ([#16844](https://github.com/Lightning-AI/lightning/pull/16844))
 
 
 ## [1.9.3] - 2023-02-21

src/pytorch_lightning/strategies/launchers/xla.py

Lines changed: 4 additions & 3 deletions

@@ -88,6 +88,8 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
 
     def _wrapping_function(
         self,
+        # XLA's multiprocessing returns the global index, not the local index as torch's multiprocessing
+        # https://github.com/pytorch/xla/blob/v1.13.0/torch_xla/distributed/xla_multiprocessing.py#L321
         process_idx: int,
         trainer: Optional["pl.Trainer"],
         function: Callable,
@@ -96,16 +98,15 @@ def _wrapping_function(
         return_queue: SimpleQueue,
         global_states: Optional[_GlobalStateSnapshot] = None,
     ) -> None:
-        self._strategy._local_rank = process_idx
         results = function(*args, **kwargs)
 
         if trainer is not None:
             results = self._collect_rank_zero_results(trainer, results)
 
-        if process_idx == 0:
+        if self._strategy.local_rank == 0:
             return_queue.put(move_data_to_device(results, "cpu"))
 
-        _rank_teardown(process_idx)
+        _rank_teardown(self._strategy.local_rank)
 
     def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
         rank_zero_debug("Collecting results from rank 0 process.")

src/pytorch_lightning/strategies/tpu_spawn.py

Lines changed: 9 additions & 0 deletions

@@ -97,6 +97,10 @@ def root_device(self) -> torch.device:
 
         return xm.xla_device()
 
+    @property
+    def local_rank(self) -> int:
+        return self.cluster_environment.local_rank() if self.cluster_environment is not None else 0
+
     @staticmethod
     def _validate_dataloader(dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None:
         def check_has_len(dataloader: DataLoader) -> None:
@@ -234,6 +238,11 @@ def setup_distributed(self) -> None:
         self.set_world_ranks()
         rank_zero_only.rank = self.global_rank
 
+    def set_world_ranks(self) -> None:
+        if self.cluster_environment is None:
+            return
+        rank_zero_only.rank = self.cluster_environment.global_rank()
+
     def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         assert self.model is not None
         with self.precision_plugin.val_step_context():
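
The `set_world_ranks` override mirrors the fabric change above: instead of pushing precomputed ranks into the cluster environment, it only refreshes `rank_zero_only.rank` from the environment. A small sketch of the contract behind that refresh, using a stand-in with the same behavior Lightning's `rank_zero_only` utility relies on (illustrative, not the library implementation):

```python
# Sketch of the rank_zero_only contract: a decorator whose .rank
# attribute is refreshed per process by set_world_ranks(); decorated
# functions run only where rank == 0. Stand-in, not Lightning source.
import functools

class _RankZeroOnly:
    rank = 0  # refreshed by set_world_ranks() in each process

    def __call__(self, fn):
        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            if self.rank == 0:
                return fn(*args, **kwargs)
            return None
        return wrapped

rank_zero_only = _RankZeroOnly()

@rank_zero_only
def log_once(msg: str) -> None:
    print(msg)

rank_zero_only.rank = 3               # e.g. a worker on a pod host
log_once("only rank 0 prints this")   # no output on rank 3
```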
