Skip to content

Commit 7374bc8

Browse files
tchatoncarmocca
authored andcommitted
Torch Elastic DDP DeadLock bug fix (#8655)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 7c042f3 commit 7374bc8

File tree

4 files changed

+61
-9
lines changed

4 files changed

+61
-9
lines changed

.azure-pipelines/gpu-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ jobs:
5151
- bash: |
5252
python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
5353
pip install fairscale>=0.3.4
54-
pip install "deepspeed>=0.4.0, !=0.4.4" # FIXME: bug with 0.4.4
54+
pip install "deepspeed>=0.4.3, !=0.4.4" # FIXME: bug with 0.4.4
5555
pip install . --requirement requirements/devel.txt
5656
pip list
5757
displayName: 'Install dependencies'

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,6 @@ def _call_children_scripts(self):
179179
os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
180180
os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())
181181

182-
# create a temporary directory used to synchronize processes on deadlock.
183-
os.environ["PL_DDP_SYNC_TMPDIR"] = self._sync_dir = tempfile.mkdtemp()
184-
185182
# Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c`
186183
# See https://docs.python.org/3/reference/import.html#main-spec
187184
if __main__.__spec__ is None: # pragma: no-cover
@@ -410,8 +407,18 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
410407
def _share_information_to_prevent_deadlock(self):
411408
self._share_pids()
412409

413-
# remove `PL_DDP_SYNC_TMPDIR` from os.environ
414-
self._sync_dir = os.environ.pop("PL_DDP_SYNC_TMPDIR", None)
410+
# there should be a unique sync_dir per nodes.
411+
if self.local_rank == 0:
412+
# create a temporary directory used to synchronize processes on deadlock.
413+
self._sync_dir = tempfile.mkdtemp()
414+
415+
sync_dirs = []
416+
global_node_rank_zero = 0
417+
for _ in range(self.num_nodes):
418+
sync_dirs.append(self.broadcast(self._sync_dir, global_node_rank_zero))
419+
global_node_rank_zero += self.world_size // self.num_nodes
420+
421+
self._sync_dir = sync_dirs[self.node_rank]
415422

416423
def _share_pids(self):
417424
"""
@@ -436,11 +443,11 @@ def reconciliate_processes(self, trace: str):
436443

437444
# return if all processes wrote a file in the `sync_dir`.
438445
# todo (tchaton) Add support for non-shared file-system which will fail.
439-
if len(os.listdir(sync_dir)) == self.world_size:
446+
if len(os.listdir(sync_dir)) == (self.world_size // self.num_nodes):
440447
return
441448

442449
for pid in self._pids:
443450
if pid != os.getpid():
444451
os.kill(pid, signal.SIGKILL)
445-
shutil.rmtree(sync_dir)
446-
raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
452+
shutil.rmtree(sync_dir)
453+
raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import os
2+
import sys
3+
from contextlib import suppress
4+
5+
from pytorch_lightning import Trainer
6+
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
7+
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
8+
from tests.helpers.boring_model import BoringModel
9+
10+
if os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") == "1":
11+
12+
class CustomException(Exception):
13+
pass
14+
15+
class Model(BoringModel):
16+
def training_step(self, batch, batch_idx):
17+
if batch_idx == 1 and self.trainer.is_global_zero:
18+
# rank 0: raises an exception
19+
# rank 1: continues training but will hang on the next barrier in the training loop
20+
raise CustomException
21+
return super().training_step(batch, batch_idx)
22+
23+
model = Model()
24+
25+
trainer = Trainer(
26+
default_root_dir=".", max_epochs=1, limit_train_batches=5, num_sanity_val_steps=0, gpus=2, accelerator="ddp"
27+
)
28+
assert isinstance(trainer.training_type_plugin, DDPPlugin)
29+
30+
with suppress(DeadlockDetectedException):
31+
# simulate random failure in training_step on rank 0
32+
trainer.fit(model)
33+
34+
# used to capture success from this script in the CI.
35+
print("SUCCEEDED")
36+
37+
sys.exit(0)

tests/special_tests.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@ if [ $? -eq 0 ]; then
7878
report+="Ran\ttests/utilities/test_warnings.py\n"
7979
fi
8080

81+
# TODO: enable when CI uses torch>=1.9
82+
# test deadlock is properly handled with TorchElastic.
83+
# LOGS=$(PL_RUNNING_SPECIAL_TESTS=1 python -m torch.distributed.run --nproc_per_node=2 --max_restarts 0 -m coverage run --source pytorch_lightning -a tests/plugins/environments/torch_elastic_deadlock.py | grep "SUCCEEDED")
84+
# if [ -z "$LOGS" ]; then
85+
# exit 1
86+
# fi
87+
# report+="Ran\ttests/plugins/environments/torch_elastic_deadlock.py\n"
88+
8189
# test that a user can manually launch individual processes
8290
args="--trainer.gpus 2 --trainer.accelerator ddp --trainer.fast_dev_run 1"
8391
MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=1 python pl_examples/basic_examples/simple_image_classifier.py ${args} &

0 commit comments

Comments
 (0)