@@ -260,6 +260,9 @@ def __init__(
260260 self .best_model_path = ""
261261 self .last_model_path = ""
262262 self ._last_checkpoint_saved = ""
263+ # When using step/time-based checkpointing with a validation-only monitored metric,
264+ # defer the save until validation has produced the metric
265+ self ._defer_save_until_validation : bool = False
263266
264267 self .kth_value : Tensor
265268 self .dirpath : Optional [_PATH ]
@@ -306,14 +309,17 @@ def on_train_batch_end(
306309 batch_idx : int ,
307310 ) -> None :
308311 """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
309- if self ._should_skip_saving_checkpoint (trainer ):
310- return
312+ # Do not return early here because we may need to set deferral flags even
313+ # if a save already happened at this global step. We'll enforce the skip
314+ # just before actually saving below.
315+ skip_due_to_state = self ._should_skip_saving_checkpoint (trainer )
311316 skip_batch = self ._every_n_train_steps < 1 or (trainer .global_step % self ._every_n_train_steps != 0 )
312317
313318 train_time_interval = self ._train_time_interval
314319 skip_time = True
315320 now = time .monotonic ()
316- if train_time_interval :
321+ # Important: allow zero timedelta as a valid interval
322+ if train_time_interval is not None :
317323 prev_time_check = self ._last_time_checked
318324 skip_time = prev_time_check is None or (now - prev_time_check ) < train_time_interval .total_seconds ()
319325 # in case we have time differences across ranks
@@ -326,6 +332,42 @@ def on_train_batch_end(
326332 self ._last_time_checked = now
327333
328334 monitor_candidates = self ._monitor_candidates (trainer )
335+ # If monitoring a metric that is not yet available (e.g., validation-only),
336+ # defer saving until validation end so the metric is present.
337+ if self .monitor is not None and self .monitor not in monitor_candidates :
338+ # Defer both top-k and last to avoid blocking with `_last_global_step_saved`
339+ self ._defer_save_until_validation = True
340+ return
341+
342+ # Even if the monitored key exists, it could be stale from a previous validation.
343+ # If validation is scheduled to run right after this batch (e.g., last batch of epoch)
344+ # and we are not saving at train epoch end, defer to `on_validation_end` to use fresh metrics.
345+ if (
346+ self .monitor is not None
347+ and not self ._should_save_on_train_epoch_end (trainer )
348+ and getattr (trainer .fit_loop .epoch_loop .batch_progress , "is_last_batch" , False )
349+ ):
350+ # Only defer if a validation loop is expected to run after this batch.
351+ will_run_val = False
352+ if getattr (trainer , "enable_validation" , False ):
353+ num_val_batches = (
354+ sum (trainer .num_val_batches )
355+ if isinstance (trainer .num_val_batches , list )
356+ else trainer .num_val_batches
357+ )
358+ if num_val_batches and num_val_batches > 0 :
359+ cve = trainer .check_val_every_n_epoch
360+ if cve is None or ((trainer .current_epoch + 1 ) % cve == 0 ):
361+ will_run_val = True
362+
363+ if will_run_val :
364+ self ._defer_save_until_validation = True
365+ return
366+
367+ # Only proceed to save if not skipping due to trainer/callback state
368+ if skip_due_to_state :
369+ return
370+
329371 self ._save_topk_checkpoint (trainer , monitor_candidates )
330372 self ._save_last_checkpoint (trainer , monitor_candidates )
331373
@@ -343,6 +385,14 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
343385 """Save a checkpoint at the end of the validation stage."""
344386 if not self ._should_skip_saving_checkpoint (trainer ) and not self ._should_save_on_train_epoch_end (trainer ):
345387 monitor_candidates = self ._monitor_candidates (trainer )
388+ # If a step/time-triggered save was deferred due to a missing monitored metric,
389+ # perform the save now that validation metrics are available.
390+ if self ._defer_save_until_validation :
391+ self ._save_topk_checkpoint (trainer , monitor_candidates )
392+ self ._save_last_checkpoint (trainer , monitor_candidates )
393+ self ._defer_save_until_validation = False
394+ return
395+
346396 if self ._every_n_epochs >= 1 and (trainer .current_epoch + 1 ) % self ._every_n_epochs == 0 :
347397 self ._save_topk_checkpoint (trainer , monitor_candidates )
348398 self ._save_last_checkpoint (trainer , monitor_candidates )
0 commit comments