
Commit 989b759

Authored by KAVYANSHTYAGI, pre-commit-ci[bot], and Borda
Fix: Synchronize SIGTERM Handling in DDP to Prevent Deadlocks (Lightning-AI#20825)
* Update signal_connector.py
* Update training_epoch_loop.py
* Create test_ddp_sigterm_handling.py
* update + chlog
* Apply suggestions from code review

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka B <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 7b8ff1d commit 989b759

File tree

5 files changed: +119 −4 lines changed


src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -28,6 +28,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed logger_connector has edge case where step can be a float ([#20692](https://github.com/Lightning-AI/pytorch-lightning/issues/20692))
 
 
+- Fix: Synchronize SIGTERM Handling in DDP to Prevent Deadlocks ([#20825](https://github.com/Lightning-AI/pytorch-lightning/pull/20825))
+
+
 ---
 
 ## [2.5.1] - 2025-03-18

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 24 additions & 0 deletions
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import contextlib
 import math
 from collections import OrderedDict
 from dataclasses import dataclass
 from typing import Any, Optional, Union
 
+import torch
 from typing_extensions import override
 
 import lightning.pytorch as pl
@@ -249,6 +251,21 @@ def _on_before_fetch(self) -> None:
     def _on_after_fetch(self) -> None:
         self.trainer.profiler.stop(f"[{self.__class__.__name__}].train_dataloader_next")
 
+    def _broadcast_sigterm_tensor(self) -> None:
+        try:
+            sigterm_tensor = torch.tensor(
+                [1 if getattr(self.trainer, "received_sigterm", False) else 0],
+                device=self.trainer.strategy.root_device,
+            )
+            torch.distributed.broadcast(sigterm_tensor, src=0)
+        except Exception:
+            sigterm_tensor = torch.tensor([0], device=self.trainer.strategy.root_device)
+
+        if sigterm_tensor.item() == 1:
+            with contextlib.suppress(Exception):
+                torch.distributed.barrier()  # prevent deadlocks by syncing all ranks before exit
+            raise SIGTERMException()
+
     def advance(self, data_fetcher: _DataFetcher) -> None:
         """Runs a single training batch.
 
@@ -272,6 +289,13 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
             # we are going to train first so the val loop does not need to restart
             self.val_loop.restarting = False
 
+        # =====================================================================
+
+        if torch.distributed.is_available() and torch.distributed.is_initialized() and self.trainer.world_size > 1:
+            self._broadcast_sigterm_tensor()
+
+        # =====================================================================
+
         if using_dataloader_iter := isinstance(data_fetcher, _DataLoaderIterDataFetcher):
             dataloader_iter = next(data_fetcher)
             # hook's batch_idx and dataloader_idx arguments correctness cannot be guaranteed in this setting
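The pattern this hook relies on can be shown outside Lightning: rank 0 encodes its local SIGTERM state in a one-element tensor, a broadcast makes every rank see the same value, and a barrier keeps the ranks together before they exit. The sketch below is illustrative only; the gloo backend, port, and `run` helper are assumptions, not code from this commit.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    # Illustrative process-group setup (gloo so the sketch runs on CPU).
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Pretend only rank 0 received SIGTERM locally.
    received_sigterm = rank == 0

    # Every rank participates in the broadcast and keeps rank 0's value,
    # so all ranks agree on whether a shutdown was requested.
    flag = torch.tensor([1 if received_sigterm else 0])
    dist.broadcast(flag, src=0)

    if flag.item() == 1:
        dist.barrier()  # keep ranks together so none is left waiting in a collective
        print(f"rank {rank}: shutting down cleanly")

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)
```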

src/lightning/pytorch/trainer/connectors/callback_connector.py

Lines changed: 3 additions & 2 deletions
@@ -106,8 +106,9 @@ def _configure_checkpoint_callbacks(self, enable_checkpointing: bool) -> None:
                 model_checkpoint = LitModelCheckpoint(model_registry=self.trainer._model_registry)
             else:
                 rank_zero_info(
-                    "Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable"
-                    " `LitModelCheckpoint` for automatic upload to the Lightning model registry."
+                    "💡 Tip: For seamless cloud uploads and versioning,"
+                    " try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint,"
+                    " which syncs automatically with the Lightning model registry."
                 )
                 model_checkpoint = ModelCheckpoint()
             self.trainer.callbacks.append(model_checkpoint)
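For context, this tip is logged on the branch where Lightning falls back to the default `ModelCheckpoint` because no checkpoint callback was supplied and `litmodels` is not available. A hedged sketch of configuring checkpointing explicitly instead; the `dirpath`, `monitor`, and `save_top_k` values are illustrative assumptions, not part of this commit.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# Explicit checkpoint callback: Lightning uses this instead of its default.
checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/",  # where checkpoint files are written
    monitor="val_loss",      # a metric logged by the LightningModule
    save_top_k=1,
)

trainer = Trainer(callbacks=[checkpoint_cb])

# The optional package mentioned in the tip is installed separately:
#   pip install litmodels
```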

src/lightning/pytorch/trainer/connectors/signal_connector.py

Lines changed: 9 additions & 2 deletions
@@ -7,6 +7,9 @@
 from types import FrameType
 from typing import Any, Callable, Union
 
+import torch
+import torch.distributed as dist
+
 import lightning.pytorch as pl
 from lightning.fabric.plugins.environments import SLURMEnvironment
 from lightning.fabric.utilities.imports import _IS_WINDOWS
@@ -104,12 +107,16 @@ def _slurm_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
 
     def _sigterm_notifier_fn(self, signum: _SIGNUM, _: FrameType) -> None:
         log.info(rank_prefixed_message(f"Received SIGTERM: {signum}", self.trainer.local_rank))
-        # subprocesses killing the parent process is not supported, only the parent (rank 0) does it
         if not self.received_sigterm:
-            # send the same signal to the subprocesses
            launcher = self.trainer.strategy.launcher
            if launcher is not None:
                launcher.kill(signum)
+
+            # New broadcast logic
+            if dist.is_available() and dist.is_initialized() and self.trainer.world_size > 1:
+                sigterm_tensor = torch.tensor([1], device=self.trainer.strategy.root_device)
+                dist.broadcast(sigterm_tensor, src=0)
+
         self.received_sigterm = True
 
     def _sigterm_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
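The handler side of this design can be shown with the standard `signal` module: the callback stays small and simply records that the signal arrived, while heavier work happens later at a safe point. In the sketch below the module-level flag stands in for the connector's `received_sigterm` attribute, and the `sigterm_notifier` name is an assumption for illustration.

```python
import signal
import types
from typing import Optional

# Stand-in for the connector's `received_sigterm` attribute.
received_sigterm = False


def sigterm_notifier(signum: int, frame: Optional[types.FrameType]) -> None:
    # Keep signal handlers minimal: just record that SIGTERM arrived.
    # (In the connector, rank 0 additionally broadcasts a flag tensor so
    # non-zero ranks learn about the shutdown even if their own handler
    # never runs.)
    global received_sigterm
    received_sigterm = True


signal.signal(signal.SIGTERM, sigterm_notifier)
```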
test_ddp_sigterm_handling.py (new file)

Lines changed: 80 additions & 0 deletions

@@ -0,0 +1,80 @@
+import os
+import signal
+import time
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+
+from lightning.pytorch import LightningModule, Trainer, seed_everything
+from lightning.pytorch.demos.boring_classes import BoringDataModule
+from lightning.pytorch.strategies.ddp import DDPStrategy
+from lightning.pytorch.utilities.exceptions import SIGTERMException
+
+# Skip the test if DDP or multiple devices are not available
+
+pytestmark = pytest.mark.skipif(
+    not torch.distributed.is_available() or torch.cuda.device_count() < 2,
+    reason="Requires torch.distributed and at least 2 CUDA devices",
+)
+
+
+class DummyModel(LightningModule):
+    def training_step(self, batch, batch_idx):
+        # Simulate SIGTERM in rank 0 at batch 2
+        if self.trainer.global_rank == 0 and batch_idx == 2:
+            time.sleep(3)  # Let other ranks proceed to the next batch
+            os.kill(os.getpid(), signal.SIGTERM)
+        return super().training_step(batch, batch_idx)
+
+
+def run_ddp_sigterm(rank, world_size, tmpdir):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    os.environ["RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+
+    seed_everything(42)
+
+    torch.cuda.set_device(rank) if torch.cuda.is_available() else None
+
+    model = DummyModel()
+    datamodule = BoringDataModule()
+
+    trainer = Trainer(
+        accelerator="cuda" if torch.cuda.is_available() else "cpu",
+        strategy=DDPStrategy(find_unused_parameters=False),
+        devices=world_size,
+        num_nodes=1,
+        max_epochs=3,
+        default_root_dir=tmpdir,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+    )
+
+    try:
+        trainer.fit(model, datamodule=datamodule)
+    except SIGTERMException:
+        # Test passed: SIGTERM was properly raised and caught
+        print(f"[Rank {rank}] Caught SIGTERMException successfully.")
+    except Exception as e:
+        pytest.fail(f"[Rank {rank}] Unexpected exception: {e}")
+
+
+def test_ddp_sigterm_handling(tmp_path):
+    world_size = 2
+    mp.spawn(run_ddp_sigterm, args=(world_size, tmp_path), nprocs=world_size, join=True)
+
+
+@pytest.mark.skipif(
+    not torch.distributed.is_available(),
+    reason="Requires torch.distributed",
+)
+@pytest.mark.skipif(
+    torch.cuda.is_available() and torch.cuda.device_count() < 2,
+    reason="Requires >=2 CUDA devices or use CPU",
+)
+def test_sigterm_handling_ddp(tmp_path):
+    test_ddp_sigterm_handling(tmp_path)
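As the test's `except SIGTERMException` block suggests, once the flag has been broadcast every rank raises the same exception, so user code can react to it in one place. A minimal sketch of doing so around `Trainer.fit`; the `BoringModel`, trainer arguments, and checkpoint path are illustrative assumptions rather than code from this commit.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities.exceptions import SIGTERMException

model = BoringModel()
trainer = Trainer(max_epochs=3, logger=False, enable_checkpointing=False)

try:
    trainer.fit(model)
except SIGTERMException:
    # With the synchronized handling, every rank reaches this point together,
    # so rank-aware cleanup such as saving a final checkpoint is safe.
    if trainer.is_global_zero:
        trainer.save_checkpoint("on_sigterm.ckpt")
```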
