Skip to content

Commit f50b3a9

Browse files
Update training_epoch_loop.py
1 parent: f327aa7 · commit: f50b3a9

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -275,34 +275,42 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
275275
self.val_loop.restarting = False
276276

277277
# =====================================================================
278-
# NEW: Check for SIGTERM broadcast and exit synchronously across ranks
278+
# FINAL: Check for SIGTERM broadcast and exit synchronously across ranks
279279
from lightning.pytorch.utilities.exceptions import SIGTERMException
280-
280+
281+
# Rank 0 broadcasts SIGTERM status
281282
if (
282283
dist.is_available()
283284
and dist.is_initialized()
284285
and getattr(self.trainer.strategy, "global_rank", 0) == 0
285286
and self.trainer.world_size > 1
286287
):
287288
try:
288-
sigterm_tensor = torch.tensor([0], device=self.trainer.strategy.root_device)
289+
sigterm_tensor = torch.tensor(
290+
[1 if self.trainer.received_sigterm else 0],
291+
device=self.trainer.strategy.root_device,
292+
)
289293
dist.broadcast(sigterm_tensor, src=0)
290294
except Exception:
291-
# log or pass silently to avoid crashing tests on CPU CI
292-
pass
293-
294-
if dist.is_available() and dist.is_initialized() and self.trainer.world_size > 1:
295-
sigterm_tensor = torch.tensor([0], device=self.trainer.strategy.root_device)
295+
pass # Ignore broadcast error on non-DDP setups
296+
297+
# All ranks listen for SIGTERM
298+
if (
299+
dist.is_available()
300+
and dist.is_initialized()
301+
and self.trainer.world_size > 1
302+
):
296303
try:
304+
sigterm_tensor = torch.tensor([0], device=self.trainer.strategy.root_device)
297305
dist.broadcast(sigterm_tensor, src=0)
298306
if sigterm_tensor.item() == 1:
299307
dist.barrier()
300308
raise SIGTERMException()
301309
except Exception:
302-
# Fallback safety: log and skip gracefully
303-
pass
310+
pass # Fallback for CPU/CI environments
304311
# =====================================================================
305312

313+
306314
if using_dataloader_iter := isinstance(data_fetcher, _DataLoaderIterDataFetcher):
307315
dataloader_iter = next(data_fetcher)
308316
# hook's batch_idx and dataloader_idx arguments correctness cannot be guaranteed in this setting

0 commit comments

Comments (0)