@@ -275,34 +275,42 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         self.val_loop.restarting = False
 
         # =====================================================================
-        # NEW: Check for SIGTERM broadcast and exit synchronously across ranks
+        # FINAL: Check for SIGTERM broadcast and exit synchronously across ranks
         from lightning.pytorch.utilities.exceptions import SIGTERMException
-
+
+        # Rank 0 broadcasts SIGTERM status
         if (
             dist.is_available()
             and dist.is_initialized()
             and getattr(self.trainer.strategy, "global_rank", 0) == 0
             and self.trainer.world_size > 1
         ):
             try:
-                sigterm_tensor = torch.tensor([0], device=self.trainer.strategy.root_device)
+                sigterm_tensor = torch.tensor(
+                    [1 if self.trainer.received_sigterm else 0],
+                    device=self.trainer.strategy.root_device,
+                )
                 dist.broadcast(sigterm_tensor, src=0)
             except Exception:
-                # log or pass silently to avoid crashing tests on CPU CI
-                pass
-
-        if dist.is_available() and dist.is_initialized() and self.trainer.world_size > 1:
-            sigterm_tensor = torch.tensor([0], device=self.trainer.strategy.root_device)
+                pass  # Ignore broadcast error on non-DDP setups
+
+        # All ranks listen for SIGTERM
+        if (
+            dist.is_available()
+            and dist.is_initialized()
+            and self.trainer.world_size > 1
+        ):
             try:
+                sigterm_tensor = torch.tensor([0], device=self.trainer.strategy.root_device)
                 dist.broadcast(sigterm_tensor, src=0)
                 if sigterm_tensor.item() == 1:
                     dist.barrier()
                     raise SIGTERMException()
             except Exception:
-                # Fallback safety: log and skip gracefully
-                pass
+                pass  # Fallback for CPU/CI environments
         # =====================================================================
 
+
         if using_dataloader_iter := isinstance(data_fetcher, _DataLoaderIterDataFetcher):
             dataloader_iter = next(data_fetcher)
             # hook's batch_idx and dataloader_idx arguments correctness cannot be guaranteed in this setting
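
The pattern in this hunk can be tried outside of Lightning with a minimal, self-contained sketch of the same idea: rank 0 encodes its stop flag in a tensor and broadcasts it, and every rank reads the same value and synchronizes on a barrier before stopping. This is illustrative only, not the code from the commit; the gloo backend, the stop_requested flag, and the helper names check_for_stop and worker are assumptions made for the example.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def check_for_stop(stop_requested: bool) -> bool:
    """Return True on every rank if rank 0 has observed a stop request."""
    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() < 2:
        return stop_requested
    # Encode the flag as a tensor so it can travel through a collective.
    flag = torch.tensor([1 if stop_requested else 0])
    # Rank 0 sends its value; every other rank receives it in-place.
    dist.broadcast(flag, src=0)
    if flag.item() == 1:
        # Make sure all ranks reach this point before anyone tears down.
        dist.barrier()
        return True
    return False


def worker(rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # Pretend only rank 0 caught SIGTERM; the broadcast spreads the decision.
    stop = rank == 0
    if check_for_stop(stop):
        print(f"rank {rank}: stopping together")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)

Run with two CPU processes, both ranks print the stop message because a single collective carries rank 0's decision to everyone, which is the behaviour the hunk above is after.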