Commit 71189de

Update training_epoch_loop.py
1 parent: 22dc0ab

File tree

1 file changed: +15, -6 lines


src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 15 additions & 6 deletions
@@ -276,20 +276,29 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         self.val_loop.restarting = False

         # =====================================================================
-        # FINAL: Check for SIGTERM broadcast and exit synchronously across ranks
         from lightning.pytorch.utilities.exceptions import SIGTERMException
-
+
         if dist.is_available() and dist.is_initialized() and self.trainer.world_size > 1:
-            with suppress(Exception):  # never crash CI
+            try:
+                # Prepare the SIGTERM signal tensor (1 if signal received, else 0)
                 sigterm_tensor = torch.tensor(
                     [1 if getattr(self.trainer, "received_sigterm", False) else 0],
                     device=self.trainer.strategy.root_device,
                 )
+                # Broadcast the SIGTERM flag from rank 0 to all other ranks
                 dist.broadcast(sigterm_tensor, src=0)
+            except Exception:
+                # In case broadcast fails (e.g., CPU-only or non-DDP), fallback to no SIGTERM
+                sigterm_tensor = torch.tensor([0], device=self.trainer.strategy.root_device)
+
+            # If SIGTERM flag is set, synchronize all ranks and raise exception to exit
+            if sigterm_tensor.item() == 1:
+                try:
+                    dist.barrier()  # prevent deadlocks by syncing all ranks before exit
+                except Exception:
+                    pass  # Don't fail if barrier fails in fallback mode
+                raise SIGTERMException()

-            if sigterm_tensor.item() == 1:
-                dist.barrier()
-                raise SIGTERMException()
         # =====================================================================

         if using_dataloader_iter := isinstance(data_fetcher, _DataLoaderIterDataFetcher):
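
The change replaces the old `with suppress(Exception)` guard, which could leave `sigterm_tensor` undefined if tensor creation or the broadcast failed, with an explicit try/except fallback, and it wraps the pre-exit `dist.barrier()` so a failed barrier cannot mask the `SIGTERMException`. For reference, below is a minimal standalone sketch of the same broadcast-then-barrier shutdown pattern in plain `torch.distributed`, outside Lightning. The `worker`, `_handle_sigterm`, and `check_sigterm_across_ranks` names are illustrative assumptions, not part of this commit:

import os
import signal

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

RECEIVED_SIGTERM = False  # flipped by the signal handler in whichever process catches SIGTERM


def _handle_sigterm(signum, frame):
    global RECEIVED_SIGTERM
    RECEIVED_SIGTERM = True


def check_sigterm_across_ranks(device):
    # Every rank contributes its local flag, but dist.broadcast overwrites it
    # with rank 0's value, so rank 0 decides for everyone (as in the commit).
    flag = torch.tensor([1 if RECEIVED_SIGTERM else 0], device=device)
    dist.broadcast(flag, src=0)
    if flag.item() == 1:
        dist.barrier()  # sync all ranks so nobody exits mid-collective
        raise RuntimeError("SIGTERM received on rank 0, shutting down all ranks")


def worker(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    signal.signal(signal.SIGTERM, _handle_sigterm)
    device = torch.device("cpu")  # gloo runs on CPU; with nccl use torch.device("cuda", rank)
    try:
        for step in range(1000):
            # ... one training step would go here ...
            check_sigterm_across_ranks(device)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)

Because `dist.broadcast` overwrites every rank's tensor with rank 0's value, only a SIGTERM delivered to rank 0 triggers the shutdown; the barrier before raising keeps fast ranks from tearing down the process group while peers are still inside a collective.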
