@@ -262,6 +262,9 @@ def __init__(
262262 self .best_model_path = ""
263263 self .last_model_path = ""
264264 self ._last_checkpoint_saved = ""
265+ # When using step/time-based checkpointing with a validation-only monitored metric,
266+ # defer the save until validation has produced the metric
267+ self ._defer_save_until_validation : bool = False
265268
266269 self .kth_value : Tensor
267270 self .dirpath : Optional [_PATH ]
@@ -308,14 +311,17 @@ def on_train_batch_end(
308311 batch_idx : int ,
309312 ) -> None :
310313 """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
311- if self ._should_skip_saving_checkpoint (trainer ):
312- return
314+ # Do not return early here because we may need to set deferral flags even
315+ # if a save already happened at this global step. We'll enforce the skip
316+ # just before actually saving below.
317+ skip_due_to_state = self ._should_skip_saving_checkpoint (trainer )
313318 skip_batch = self ._every_n_train_steps < 1 or (trainer .global_step % self ._every_n_train_steps != 0 )
314319
315320 train_time_interval = self ._train_time_interval
316321 skip_time = True
317322 now = time .monotonic ()
318- if train_time_interval :
323+ # Important: allow zero timedelta as a valid interval
324+ if train_time_interval is not None :
319325 prev_time_check = self ._last_time_checked
320326 skip_time = prev_time_check is None or (now - prev_time_check ) < train_time_interval .total_seconds ()
321327 # in case we have time differences across ranks
@@ -328,6 +334,42 @@ def on_train_batch_end(
328334 self ._last_time_checked = now
329335
330336 monitor_candidates = self ._monitor_candidates (trainer )
337+ # If monitoring a metric that is not yet available (e.g., validation-only),
338+ # defer saving until validation end so the metric is present.
339+ if self .monitor is not None and self .monitor not in monitor_candidates :
340+ # Defer both top-k and last to avoid blocking with `_last_global_step_saved`
341+ self ._defer_save_until_validation = True
342+ return
343+
344+ # Even if the monitored key exists, it could be stale from a previous validation.
345+ # If validation is scheduled to run right after this batch (e.g., last batch of epoch)
346+ # and we are not saving at train epoch end, defer to `on_validation_end` to use fresh metrics.
347+ if (
348+ self .monitor is not None
349+ and not self ._should_save_on_train_epoch_end (trainer )
350+ and getattr (trainer .fit_loop .epoch_loop .batch_progress , "is_last_batch" , False )
351+ ):
352+ # Only defer if a validation loop is expected to run after this batch.
353+ will_run_val = False
354+ if getattr (trainer , "enable_validation" , False ):
355+ num_val_batches = (
356+ sum (trainer .num_val_batches )
357+ if isinstance (trainer .num_val_batches , list )
358+ else trainer .num_val_batches
359+ )
360+ if num_val_batches and num_val_batches > 0 :
361+ cve = trainer .check_val_every_n_epoch
362+ if cve is None or ((trainer .current_epoch + 1 ) % cve == 0 ):
363+ will_run_val = True
364+
365+ if will_run_val :
366+ self ._defer_save_until_validation = True
367+ return
368+
369+ # Only proceed to save if not skipping due to trainer/callback state
370+ if skip_due_to_state :
371+ return
372+
331373 self ._save_topk_checkpoint (trainer , monitor_candidates )
332374 self ._save_last_checkpoint (trainer , monitor_candidates )
333375
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Save a checkpoint at the end of the validation stage."""
    # Nothing to do if saving is globally disabled for this trainer state, or if
    # this callback saves on train-epoch end instead of validation end.
    if self._should_skip_saving_checkpoint(trainer) or self._should_save_on_train_epoch_end(trainer):
        return

    monitor_candidates = self._monitor_candidates(trainer)

    # A step/time-triggered save from `on_train_batch_end` may have been deferred
    # because the monitored metric was missing or stale (validation-only metric).
    # Perform that save now that validation has produced the metric, and return so
    # the epoch-based path below does not save a second time at this step.
    if self._defer_save_until_validation:
        self._save_topk_checkpoint(trainer, monitor_candidates)
        self._save_last_checkpoint(trainer, monitor_candidates)
        self._defer_save_until_validation = False
        return

    # Regular epoch-based checkpointing: top-k only every `_every_n_epochs`
    # epochs; "last" is refreshed on every validation end.
    if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
        self._save_topk_checkpoint(trainer, monitor_candidates)
    self._save_last_checkpoint(trainer, monitor_candidates)