Skip to content

Commit 6a1bbf1

Browse files
Update signal_connector.py
1 parent 01ba7a1 commit 6a1bbf1

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

src/lightning/pytorch/trainer/connectors/signal_connector.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from subprocess import call
77
from types import FrameType
88
from typing import Any, Callable, Union
9+
import torch
10+
import torch.distributed as dist
911

1012
import lightning.pytorch as pl
1113
from lightning.fabric.plugins.environments import SLURMEnvironment
@@ -104,14 +106,19 @@ def _slurm_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
104106

105107
def _sigterm_notifier_fn(self, signum: _SIGNUM, _: FrameType) -> None:
106108
log.info(rank_prefixed_message(f"Received SIGTERM: {signum}", self.trainer.local_rank))
107-
# subprocesses killing the parent process is not supported, only the parent (rank 0) does it
108109
if not self.received_sigterm:
109-
# send the same signal to the subprocesses
110110
launcher = self.trainer.strategy.launcher
111111
if launcher is not None:
112112
launcher.kill(signum)
113+
114+
# New broadcast logic
115+
if dist.is_available() and dist.is_initialized() and self.trainer.world_size > 1:
116+
sigterm_tensor = torch.tensor([1], device=self.trainer.strategy.root_device)
117+
dist.broadcast(sigterm_tensor, src=0)
118+
113119
self.received_sigterm = True
114120

121+
115122
def _sigterm_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
116123
log.info(f"Bypassing SIGTERM: {signum}")
117124

0 commit comments

Comments
 (0)