Commit eb1356a

tchaton authored and committed (co-authors: pre-commit-ci[bot], awaelchli, Borda)
[bugfix] Add mechanism to prevent deadlock for DDP on Exception Trigger (#8167)
* add mechanism to prevent deadlock
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* resolve flake8 + update changelog
* update on comments
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* update
* remove space
* resolve bugs
* overwrite config
* update on comments
* update on comments
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* update
* update
* update test with comments
* Update pytorch_lightning/plugins/training_type/parallel.py
  Co-authored-by: Jirka Borovec <[email protected]>
* update on comments

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent d317ebf commit eb1356a
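
For readers skimming the diffs below, the fix replaces the barrier in `pre_dispatch` with a PID exchange plus a file-marker handshake: every rank learns every other rank's PID up front; when a rank later dies on an exception, it drops a marker file into a shared temporary directory, waits a grace period, and if the other ranks did not also drop markers (meaning they are stuck on a collective op), it SIGKILLs them and raises `DeadlockDetectedException`. A minimal standalone sketch of that pattern, with two local processes standing in for DDP ranks (all names here are illustrative, not Lightning APIs; `os.kill` with SIGKILL assumes Unix):

import os
import signal
import tempfile
import time
from multiprocessing import Process, Queue

WORLD_SIZE = 2

def reconcile(sync_dir, rank, pids):
    """Run by a rank whose training loop just raised."""
    # announce the failure with a marker file, then wait a grace period
    open(os.path.join(sync_dir, f"{rank}.marker"), "w").close()
    time.sleep(3)
    # every rank wrote a marker -> everyone saw the failure, nothing to do
    if len(os.listdir(sync_dir)) == WORLD_SIZE:
        return
    # otherwise some rank is stuck (e.g. blocked on a collective): kill it
    for pid in pids:
        if pid != os.getpid():
            os.kill(pid, signal.SIGKILL)

def worker(rank, sync_dir, inbox):
    pids = inbox.get()  # the parent sends the full PID list after spawning
    try:
        if rank == 0:
            raise RuntimeError("boom on rank 0")
        time.sleep(30)  # rank 1 "hangs" here, standing in for a DDP barrier
    except Exception:
        reconcile(sync_dir, rank, pids)
        raise

if __name__ == "__main__":
    sync_dir = tempfile.mkdtemp()
    inboxes = [Queue() for _ in range(WORLD_SIZE)]
    procs = [Process(target=worker, args=(r, sync_dir, inboxes[r])) for r in range(WORLD_SIZE)]
    for p in procs:
        p.start()
    for q in inboxes:
        q.put([p.pid for p in procs])
    for p in procs:
        p.join()  # without reconcile, rank 1 would block for the full 30s
    print([p.exitcode for p in procs])  # e.g. [1, -9]: rank 1 was SIGKILLed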

File tree: 5 files changed, +102 −5 lines

pytorch_lightning/plugins/training_type/ddp.py
pytorch_lightning/plugins/training_type/parallel.py
pytorch_lightning/trainer/trainer.py
pytorch_lightning/utilities/exceptions.py
tests/trainer/test_trainer.py

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 50 additions & 3 deletions
@@ -13,8 +13,12 @@
 # limitations under the License.
 import logging
 import os
+import shutil
+import signal
 import subprocess
 import sys
+import tempfile
+import time
 from time import sleep
 from typing import Any, Dict, List, Optional, Union
 
@@ -36,7 +40,7 @@
     rank_zero_warn,
 )
 from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, ReduceOp, sync_ddp_if_available
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
 
 if _HYDRA_AVAILABLE:
@@ -82,6 +86,8 @@ def __init__(
         self._ddp_comm_state = ddp_comm_state
         self._ddp_comm_hook = ddp_comm_hook
         self._ddp_comm_wrapper = ddp_comm_wrapper
+        self._pids: Optional[List[int]] = None
+        self._sync_dir: Optional[str] = None
         self.set_world_ranks()
 
     @property
@@ -112,7 +118,6 @@ def setup_environment(self):
         self.setup_distributed()
 
     def _call_children_scripts(self):
-
         # bookkeeping of spawned processes
         assert self.local_rank == 0
         self._check_can_spawn_children()
@@ -126,6 +131,9 @@ def _call_children_scripts(self):
         os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
         os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())
 
+        # create a temporary directory used to synchronize processes on deadlock.
+        os.environ["PL_DDP_SYNC_TMPDIR"] = self._sync_dir = tempfile.mkdtemp()
+
         # when user is using hydra find the absolute path
         path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path
 
@@ -281,7 +289,8 @@ def pre_dispatch(self):
 
         self.configure_ddp()
 
-        self.barrier()
+        # share ddp pids to all processes
+        self._share_information_to_prevent_deadlock()
 
     def post_dispatch(self) -> None:
         self.cluster_environment.teardown()
@@ -344,3 +353,41 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
             description="DDP Plugin with `find_unused_parameters` as False",
             find_unused_parameters=False
         )
+
+    def _share_information_to_prevent_deadlock(self):
+        self._share_pids()
+
+        # remove `PL_DDP_SYNC_TMPDIR` from os.environ
+        self._sync_dir = os.environ.pop("PL_DDP_SYNC_TMPDIR", None)
+
+    def _share_pids(self):
+        """
+        Make all DDP processes aware of all processes pids.
+        """
+        self.barrier()
+        pids = self.all_gather(torch.tensor(os.getpid(), device=self.root_device))
+        pids = pids.cpu().numpy().tolist()
+        self._pids = pids if isinstance(pids, list) else [pids]
+
+    def reconciliate_processes(self, trace: str):
+        if self.world_size < 2:
+            return
+
+        sync_dir = self._sync_dir
+
+        # save a file locally.
+        torch.save(True, os.path.join(sync_dir, f"{self.global_rank}.pl"))
+
+        # sleep for a short time
+        time.sleep(3)
+
+        # return if all processes wrote a file in the `sync_dir`.
+        # todo (tchaton) Add support for non-shared file-system which will fail.
+        if len(os.listdir(sync_dir)) == self.world_size:
+            return
+
+        for pid in self._pids:
+            if pid != os.getpid():
+                os.kill(pid, signal.SIGKILL)
+        shutil.rmtree(sync_dir)
+        raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
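`_share_pids` rides on the plugin's `all_gather` helper; expressed with raw `torch.distributed`, the exchange looks roughly like the sketch below (assumes an already-initialized process group, shown with CPU tensors as on the gloo backend; illustrative, not the plugin code):

import os
import torch
import torch.distributed as dist

def share_pids():
    """Every rank returns the list of all ranks' PIDs."""
    dist.barrier()  # ensure every process has reached this point
    pid = torch.tensor(os.getpid())
    gathered = [torch.zeros_like(pid) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, pid)  # each rank receives every rank's PID
    return [int(t.item()) for t in gathered]

Note the `os.environ.pop("PL_DDP_SYNC_TMPDIR", None)` above: rank 0 assigns `self._sync_dir` directly in `_call_children_scripts`, while the spawned child processes recover the directory from the inherited environment variable.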

pytorch_lightning/plugins/training_type/parallel.py

Lines changed: 5 additions & 0 deletions
@@ -76,6 +76,11 @@ def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=len(self.parallel_devices), rank=self.global_rank)
         return distributed_sampler_kwargs
 
+    def reconciliate_processes(self, trace: str):
+        """
+        Function to re-conciliate processes on failure
+        """
+
     def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
         """Perform a all_gather on all processes """
         return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
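The empty body gives `ParallelPlugin` a no-op default: plugins that do not launch separate OS processes inherit a harmless hook, and `DDPPlugin` overrides it with the kill logic above. The pattern in miniature (class names hypothetical, not from the repo):

class ParallelBase:
    def reconciliate_processes(self, trace: str) -> None:
        """Default: nothing to reconcile."""

class SubprocessBackend(ParallelBase):
    def reconciliate_processes(self, trace: str) -> None:
        # a backend that owns OS processes cleans up its children here
        print(f"tearing down stragglers after:\n{trace}")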

pytorch_lightning/trainer/trainer.py

Lines changed: 5 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Trainer to automate the training."""
 import logging
+import traceback
 import warnings
 from datetime import timedelta
 from itertools import count
@@ -61,6 +62,7 @@
 from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities import DeviceType, parsing, rank_zero_warn
 from pytorch_lightning.utilities.debugging import InternalDebugger
+from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -902,6 +904,9 @@ def run_train(self) -> None:
             self.state.stage = None
         except BaseException:
             self.state.status = TrainerStatus.INTERRUPTED
+            if distributed_available() and self.world_size > 1:
+                # try syncing remaing processes, kill otherwise
+                self.training_type_plugin.reconciliate_processes(traceback.format_exc())
             # give accelerators a chance to finish
             self.accelerator.on_train_end()
             # reset bookkeeping
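
Placement matters here: the reconciliation runs inside `except BaseException`, before `self.accelerator.on_train_end()`, while the sibling ranks are still alive to receive the SIGKILL. Stripped of Trainer machinery, the control flow is roughly (a hedged sketch, not the actual `run_train` body):

import traceback

def run_train(train_fn, plugin, world_size):
    try:
        train_fn()
    except BaseException:
        if world_size > 1:
            # either every rank failed too (no-op), or the stuck ranks get killed;
            # reconciliate_processes may raise DeadlockDetectedException itself
            plugin.reconciliate_processes(traceback.format_exc())
        raise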

pytorch_lightning/utilities/exceptions.py

Lines changed: 9 additions & 1 deletion
@@ -14,4 +14,12 @@
 
 
 class MisconfigurationException(Exception):
-    pass
+    """
+    Exception used to inform users of mis-use with PyTorch Lightning
+    """
+
+
+class DeadlockDetectedException(Exception):
+    """
+    Exception used when a deadlock has been detected and processes are being killed
+    """

tests/trainer/test_trainer.py

Lines changed: 33 additions & 1 deletion
@@ -39,7 +39,7 @@
 from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
 from tests.base import EvalModelTemplate
 from tests.helpers import BoringModel, RandomDataset
@@ -2079,3 +2079,35 @@ def test_module_current_fx_attributes_reset(tmpdir):
     assert (
         model._current_dataloader_idx is None
     ), f"_current_dataloader_idx not reset after test: {model._current_dataloader_idx}"
+
+
+@RunIf(min_gpus=2, special=True)
+def test_ddp_terminate_when_deadlock_is_detected(tmpdir):
+    """ Test that DDP kills the remaining processes when only one rank is throwing an exception. """
+
+    class CustomException(Exception):
+        pass
+
+    class TestModel(BoringModel):
+
+        def training_step(self, batch, batch_idx):
+            if batch_idx == 1 and self.trainer.is_global_zero:
+                # rank 0: raises an exception
+                # rank 1: continues training but will hang on the next barrier in the training loop
+                raise CustomException
+            return super().training_step(batch, batch_idx)
+
+    model = TestModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=5,
+        num_sanity_val_steps=0,
+        gpus=2,
+        accelerator="ddp",
+    )
+
+    # simulate random failure in training_step on rank 0
+    with pytest.raises(DeadlockDetectedException, match="CustomException"):
+        trainer.fit(model)
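
The `match="CustomException"` assertion works because `reconciliate_processes` interpolates the formatted traceback of the original exception into the `DeadlockDetectedException` message, so the root cause survives the kill path. The `RunIf(min_gpus=2, special=True)` marker keeps the test out of the default CI pass, presumably because it deliberately SIGKILLs one of its own ranks.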
