diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 18ef679312a66..eaca4443ae434 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -28,6 +28,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed logger_connector has edge case where step can be a float ([#20692](https://github.com/Lightning-AI/pytorch-lightning/issues/20692)) +- Fix: Synchronize SIGTERM Handling in DDP to Prevent Deadlocks ([#20825](https://github.com/Lightning-AI/pytorch-lightning/pull/20825)) + + --- ## [2.5.1] - 2025-03-18 diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 7cdf7888bbfe2..599eccdc8ca91 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import math from collections import OrderedDict from dataclasses import dataclass from typing import Any, Optional, Union +import torch from typing_extensions import override import lightning.pytorch as pl @@ -249,6 +251,21 @@ def _on_before_fetch(self) -> None: def _on_after_fetch(self) -> None: self.trainer.profiler.stop(f"[{self.__class__.__name__}].train_dataloader_next") + def _broadcast_sigterm_tensor(self) -> None: + try: + sigterm_tensor = torch.tensor( + [1 if getattr(self.trainer, "received_sigterm", False) else 0], + device=self.trainer.strategy.root_device, + ) + torch.distributed.broadcast(sigterm_tensor, src=0) + except Exception: + sigterm_tensor = torch.tensor([0], device=self.trainer.strategy.root_device) + + if sigterm_tensor.item() == 1: + with contextlib.suppress(Exception): + torch.distributed.barrier() # prevent deadlocks by syncing all ranks before exit + raise SIGTERMException() + def advance(self, data_fetcher: _DataFetcher) -> None: """Runs a single training batch. @@ -272,6 +289,13 @@ def advance(self, data_fetcher: _DataFetcher) -> None: # we are going to train first so the val loop does not need to restart self.val_loop.restarting = False + # ===================================================================== + + if torch.distributed.is_available() and torch.distributed.is_initialized() and self.trainer.world_size > 1: + self._broadcast_sigterm_tensor() + + # ===================================================================== + if using_dataloader_iter := isinstance(data_fetcher, _DataLoaderIterDataFetcher): dataloader_iter = next(data_fetcher) # hook's batch_idx and dataloader_idx arguments correctness cannot be guaranteed in this setting diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py index 8d67081db8638..5c351aeebc564 100644 --- a/src/lightning/pytorch/trainer/connectors/callback_connector.py +++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py @@ -106,8 +106,9 @@ def _configure_checkpoint_callbacks(self, enable_checkpointing: bool) -> None: model_checkpoint = LitModelCheckpoint(model_registry=self.trainer._model_registry) else: rank_zero_info( - "Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable" - " `LitModelCheckpoint` for automatic upload to the Lightning model registry." + "💡 Tip: For seamless cloud uploads and versioning," + " try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint," + " which syncs automatically with the Lightning model registry." ) model_checkpoint = ModelCheckpoint() self.trainer.callbacks.append(model_checkpoint) diff --git a/src/lightning/pytorch/trainer/connectors/signal_connector.py b/src/lightning/pytorch/trainer/connectors/signal_connector.py index e63fecd3897f2..ece7e902c5f5f 100644 --- a/src/lightning/pytorch/trainer/connectors/signal_connector.py +++ b/src/lightning/pytorch/trainer/connectors/signal_connector.py @@ -7,6 +7,9 @@ from types import FrameType from typing import Any, Callable, Union +import torch +import torch.distributed as dist + import lightning.pytorch as pl from lightning.fabric.plugins.environments import SLURMEnvironment from lightning.fabric.utilities.imports import _IS_WINDOWS @@ -104,12 +107,16 @@ def _slurm_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None: def _sigterm_notifier_fn(self, signum: _SIGNUM, _: FrameType) -> None: log.info(rank_prefixed_message(f"Received SIGTERM: {signum}", self.trainer.local_rank)) - # subprocesses killing the parent process is not supported, only the parent (rank 0) does it if not self.received_sigterm: - # send the same signal to the subprocesses launcher = self.trainer.strategy.launcher if launcher is not None: launcher.kill(signum) + + # New broadcast logic + if dist.is_available() and dist.is_initialized() and self.trainer.world_size > 1: + sigterm_tensor = torch.tensor([1], device=self.trainer.strategy.root_device) + dist.broadcast(sigterm_tensor, src=0) + self.received_sigterm = True def _sigterm_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None: diff --git a/tests/tests_pytorch/trainer/test_ddp_sigterm_handling.py b/tests/tests_pytorch/trainer/test_ddp_sigterm_handling.py new file mode 100644 index 0000000000000..0e4e5210db60c --- /dev/null +++ b/tests/tests_pytorch/trainer/test_ddp_sigterm_handling.py @@ -0,0 +1,80 @@ +import os +import signal +import time + +import pytest +import torch +import torch.multiprocessing as mp + +from lightning.pytorch import LightningModule, Trainer, seed_everything +from lightning.pytorch.demos.boring_classes import BoringDataModule +from lightning.pytorch.strategies.ddp import DDPStrategy +from lightning.pytorch.utilities.exceptions import SIGTERMException + +# Skip the test if DDP or multiple devices are not available + +pytestmark = pytest.mark.skipif( + not torch.distributed.is_available() or torch.cuda.device_count() < 2, + reason="Requires torch.distributed and at least 2 CUDA devices", +) + + +class DummyModel(LightningModule): + def training_step(self, batch, batch_idx): + # Simulate SIGTERM in rank 0 at batch 2 + if self.trainer.global_rank == 0 and batch_idx == 2: + time.sleep(3) # Let other ranks proceed to the next batch + os.kill(os.getpid(), signal.SIGTERM) + return super().training_step(batch, batch_idx) + + +def run_ddp_sigterm(rank, world_size, tmpdir): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + seed_everything(42) + + torch.cuda.set_device(rank) if torch.cuda.is_available() else None + + model = DummyModel() + datamodule = BoringDataModule() + + trainer = Trainer( + accelerator="cuda" if torch.cuda.is_available() else "cpu", + strategy=DDPStrategy(find_unused_parameters=False), + devices=world_size, + num_nodes=1, + max_epochs=3, + default_root_dir=tmpdir, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False, + logger=False, + ) + + try: + trainer.fit(model, datamodule=datamodule) + except SIGTERMException: + # Test passed: SIGTERM was properly raised and caught + print(f"[Rank {rank}] Caught SIGTERMException successfully.") + except Exception as e: + pytest.fail(f"[Rank {rank}] Unexpected exception: {e}") + + +def test_ddp_sigterm_handling(tmp_path): + world_size = 2 + mp.spawn(run_ddp_sigterm, args=(world_size, tmp_path), nprocs=world_size, join=True) + + +@pytest.mark.skipif( + not torch.distributed.is_available(), + reason="Requires torch.distributed", +) +@pytest.mark.skipif( + torch.cuda.is_available() and torch.cuda.device_count() < 2, + reason="Requires >=2 CUDA devices or use CPU", +) +def test_sigterm_handling_ddp(tmp_path): + test_ddp_sigterm_handling(tmp_path)