Skip to content

Commit 2761ad8

Browse files
Update training_epoch_loop.py
1 parent 6a1bbf1 commit 2761ad8

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
from dataclasses import dataclass
1717
from typing import Any, Optional, Union
1818

19+
import torch
20+
import torch.distributed as dist
21+
1922
from typing_extensions import override
2023

2124
import lightning.pytorch as pl
@@ -272,6 +275,21 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
272275
# we are going to train first so the val loop does not need to restart
273276
self.val_loop.restarting = False
274277

278+
# =====================================================================
279+
# NEW: Check for SIGTERM broadcast and exit synchronously across ranks
280+
import torch
281+
import torch.distributed as dist
282+
from lightning.pytorch.utilities.exceptions import SIGTERMException
283+
if dist.is_available() and dist.is_initialized() and self.trainer.world_size > 1:
284+
# Create a tensor to receive the SIGTERM flag.
285+
sigterm_tensor = torch.tensor([0], device=self.trainer.strategy.root_device)
286+
dist.broadcast(sigterm_tensor, src=0)
287+
if sigterm_tensor.item() == 1:
288+
# synchronize all ranks before exit to prevent deadlock
289+
dist.barrier()
290+
raise SIGTERMException()
291+
# =====================================================================
292+
275293
if using_dataloader_iter := isinstance(data_fetcher, _DataLoaderIterDataFetcher):
276294
dataloader_iter = next(data_fetcher)
277295
# hook's batch_idx and dataloader_idx arguments correctness cannot be guaranteed in this setting
@@ -347,6 +365,7 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
347365
# -----------------------------------------
348366
trainer._logger_connector.update_train_step_metrics()
349367

368+
350369
def on_advance_end(self, data_fetcher: _DataFetcher) -> None:
351370
# -----------------------------------------
352371
# VALIDATE IF NEEDED

0 commit comments

Comments
 (0)