
Commit 83925f1

ananthsub authored and lexierule committed
Skip reconciliate_processes if used within a cluster environment that creates processes externally (#9389)
* [RFC] Skip reconciliate_processes if used within a cluster environment that creates processes externally
1 parent 9f95a92 commit 83925f1
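
In short, after this change the DDP plugin runs its error-time process reconciliation only when rank 0 launched the worker processes itself, or when the user opts in explicitly via the PL_RECONCILE_PROCESS environment variable. A minimal sketch of that decision rule follows; the standalone function name and the rank_0_launched_children parameter are illustrative only, since in the codebase this logic lives on the DDP plugin as _should_run_deadlock_detection (see the diff below):

    import os

    def should_run_deadlock_detection(rank_0_launched_children: bool) -> bool:
        # Opt-in override: PL_RECONCILE_PROCESS=1 forces reconciliation even when an
        # external agent (e.g. a cluster scheduler) created the DDP processes.
        if os.getenv("PL_RECONCILE_PROCESS", "0") == "1":
            return True
        # Otherwise only reconcile when Lightning itself spawned the children,
        # i.e. rank 0 called _call_children_scripts.
        return rank_0_launched_children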

5 files changed: +37 −12 lines changed


CHANGELOG.md

Lines changed: 9 additions & 0 deletions
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
+## [1.4.8] - 2021-09-21
+
+- Fixed error reporting in DDP process reconciliation when processes are launched by an external agent (#9389)
+- Added PL_RECONCILE_PROCESS environment variable to enable process reconciliation regardless of cluster environment settings (#9389)
+
+
 ## [1.4.7] - 2021-09-14
 
 - Fixed logging of nan parameters ([#9364](https://github.com/PyTorchLightning/pytorch-lightning/pull/9364))

@@ -34,6 +40,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed signature of `Timer.on_train_epoch_end` and `StochasticWeightAveraging.on_train_epoch_end` to prevent unwanted deprecation warnings ([#9347](https://github.com/PyTorchLightning/pytorch-lightning/pull/9347))
 
 
+- Fixed error reporting in DDP process reconciliation when processes are launched by an external agent ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))
+
+
 ## [1.4.5] - 2021-08-31
 
 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 22 additions & 7 deletions
@@ -107,6 +107,7 @@ def __init__(
         self._ddp_comm_wrapper = ddp_comm_wrapper
         self._pids: Optional[List[int]] = None
         self._sync_dir: Optional[str] = None
+        self._rank_0_has_called_call_children_scripts: bool = False
         self.set_world_ranks()
 
     @property
@@ -235,6 +236,8 @@ def _call_children_scripts(self):
             delay = np.random.uniform(1, 5, 1)[0]
             sleep(delay)
 
+        self._rank_0_has_called_call_children_scripts = True
+
     def setup_distributed(self):
         reset_seed()
 
@@ -331,7 +334,9 @@ def init_ddp_connection(self, global_rank: Optional[int] = None, world_size: Opt
 
     def pre_dispatch(self):
         # share ddp pids to all processes
-        self._share_information_to_prevent_deadlock()
+        self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts)
+        if self._should_run_deadlock_detection():
+            self._share_information_to_prevent_deadlock()
 
         # move the model to the correct device
         self.model_to_device()
@@ -405,7 +410,16 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
             find_unused_parameters=False,
         )
 
-    def _share_information_to_prevent_deadlock(self):
+    def _should_run_deadlock_detection(self) -> bool:
+        """Determines whether the plugin will perform process reconciliation in case of errors.
+
+        If the environment variable `PL_RECONCILE_PROCESS` is set, run detection regardless of the cluster environment.
+        By default this is disabled. Otherwise, if the cluster environment creates the processes, allow the scheduler /
+        parent process to perform the process termination, external to Lightning.
+        """
+        return os.getenv("PL_RECONCILE_PROCESS", "0") == "1" or self._rank_0_has_called_call_children_scripts
+
+    def _share_information_to_prevent_deadlock(self) -> None:
         self._share_pids()
 
         # there should be a unique sync_dir per nodes.
@@ -421,19 +435,20 @@ def _share_information_to_prevent_deadlock(self):
 
         self._sync_dir = sync_dirs[self.node_rank]
 
-    def _share_pids(self):
-        """
-        Make all DDP processes aware of all processes pids.
-        """
+    def _share_pids(self) -> None:
+        """Make all DDP processes aware of all processes pids."""
         self.barrier()
         pids = self.all_gather(torch.tensor(os.getpid(), device=self.root_device))
         pids = pids.cpu().numpy().tolist()
         self._pids = pids if isinstance(pids, list) else [pids]
 
-    def reconciliate_processes(self, trace: str):
+    def reconciliate_processes(self, trace: str) -> None:
         if self.world_size < 2:
             return
 
+        if not self._should_run_deadlock_detection():
+            return
+
         sync_dir = self._sync_dir
 
         if not sync_dir:
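
For users whose processes are created externally (e.g. by SLURM or torch.distributed.run), reconciliation is now skipped by default; setting PL_RECONCILE_PROCESS=1 before the Trainer starts turns it back on. A hedged usage sketch follows; setting the variable in-script is just one option (it can equally be exported in the shell or job script), and the Trainer arguments are illustrative for a 1.4.x-style two-GPU DDP run, not taken from this diff:

    import os

    from pytorch_lightning import Trainer

    # Assumption for illustration: processes are launched by an external agent,
    # but we still want Lightning to terminate peer processes when one rank fails.
    os.environ["PL_RECONCILE_PROCESS"] = "1"

    # Illustrative 1.4.x-style Trainer configuration; call trainer.fit(model) as usual.
    trainer = Trainer(gpus=2, accelerator="ddp")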

tests/plugins/environments/torch_elastic_deadlock.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
 from tests.helpers.boring_model import BoringModel
 
-if os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") == "1":
+if os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") == "1" and os.getenv("PL_RECONCILE_PROCESS", "0") == "1":
 
     class CustomException(Exception):
         pass

tests/special_tests.sh

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ fi
 
 # TODO: enable when CI uses torch>=1.9
 # test deadlock is properly handled with TorchElastic.
-# LOGS=$(PL_RUNNING_SPECIAL_TESTS=1 python -m torch.distributed.run --nproc_per_node=2 --max_restarts 0 -m coverage run --source pytorch_lightning -a tests/plugins/environments/torch_elastic_deadlock.py | grep "SUCCEEDED")
+# LOGS=$(PL_RUNNING_SPECIAL_TESTS=1 PL_RECONCILE_PROCESS=1 python -m torch.distributed.run --nproc_per_node=2 --max_restarts 0 -m coverage run --source pytorch_lightning -a tests/plugins/environments/torch_elastic_deadlock.py | grep "SUCCEEDED")
 # if [ -z "$LOGS" ]; then
 #     exit 1
 # fi

tests/trainer/test_trainer.py

Lines changed: 4 additions & 3 deletions
@@ -1818,13 +1818,14 @@ def test_exception_when_lightning_module_is_not_set_on_trainer():
         trainer.predict()
 
 
+class CustomException(Exception):
+    pass
+
+
 @RunIf(min_gpus=2, special=True)
 def test_ddp_terminate_when_deadlock_is_detected(tmpdir):
     """Test that DDP kills the remaining processes when only one rank is throwing an exception."""
 
-    class CustomException(Exception):
-        pass
-
     class TestModel(BoringModel):
         def training_step(self, batch, batch_idx):
             if batch_idx == 1 and self.trainer.is_global_zero:
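
This test raises CustomException on global rank 0 only; with reconciliation active, the remaining rank is expected to be terminated and the failure surfaced as a DeadlockDetectedException (the exception type imported in torch_elastic_deadlock.py above). A condensed, hedged sketch of that pattern; the function name run_deadlock_example and the Trainer arguments are illustrative, and the real test uses additional limits and flags:

    import pytest

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
    from tests.helpers.boring_model import BoringModel


    class CustomException(Exception):
        pass


    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            # Only global rank 0 fails; the DDP plugin is expected to reconcile the other ranks.
            if batch_idx == 1 and self.trainer.is_global_zero:
                raise CustomException
            return super().training_step(batch, batch_idx)


    def run_deadlock_example(tmpdir):
        # Illustrative arguments for a two-GPU 1.4.x-style DDP run.
        trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=2, accelerator="ddp")
        with pytest.raises(DeadlockDetectedException):
            trainer.fit(TestModel())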

0 commit comments