@@ -260,6 +260,9 @@ def __init__(
260
260
self .best_model_path = ""
261
261
self .last_model_path = ""
262
262
self ._last_checkpoint_saved = ""
263
+ # When using step/time-based checkpointing with a validation-only monitored metric,
264
+ # defer the save until validation has produced the metric
265
+ self ._defer_save_until_validation : bool = False
263
266
264
267
self .kth_value : Tensor
265
268
self .dirpath : Optional [_PATH ]
@@ -306,14 +309,17 @@ def on_train_batch_end(
306
309
batch_idx : int ,
307
310
) -> None :
308
311
"""Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
309
- if self ._should_skip_saving_checkpoint (trainer ):
310
- return
312
+ # Do not return early here because we may need to set deferral flags even
313
+ # if a save already happened at this global step. We'll enforce the skip
314
+ # just before actually saving below.
315
+ skip_due_to_state = self ._should_skip_saving_checkpoint (trainer )
311
316
skip_batch = self ._every_n_train_steps < 1 or (trainer .global_step % self ._every_n_train_steps != 0 )
312
317
313
318
train_time_interval = self ._train_time_interval
314
319
skip_time = True
315
320
now = time .monotonic ()
316
- if train_time_interval :
321
+ # Important: allow zero timedelta as a valid interval
322
+ if train_time_interval is not None :
317
323
prev_time_check = self ._last_time_checked
318
324
skip_time = prev_time_check is None or (now - prev_time_check ) < train_time_interval .total_seconds ()
319
325
# in case we have time differences across ranks
@@ -326,6 +332,42 @@ def on_train_batch_end(
326
332
self ._last_time_checked = now
327
333
328
334
monitor_candidates = self ._monitor_candidates (trainer )
335
+ # If monitoring a metric that is not yet available (e.g., validation-only),
336
+ # defer saving until validation end so the metric is present.
337
+ if self .monitor is not None and self .monitor not in monitor_candidates :
338
+ # Defer both top-k and last to avoid blocking with `_last_global_step_saved`
339
+ self ._defer_save_until_validation = True
340
+ return
341
+
342
+ # Even if the monitored key exists, it could be stale from a previous validation.
343
+ # If validation is scheduled to run right after this batch (e.g., last batch of epoch)
344
+ # and we are not saving at train epoch end, defer to `on_validation_end` to use fresh metrics.
345
+ if (
346
+ self .monitor is not None
347
+ and not self ._should_save_on_train_epoch_end (trainer )
348
+ and getattr (trainer .fit_loop .epoch_loop .batch_progress , "is_last_batch" , False )
349
+ ):
350
+ # Only defer if a validation loop is expected to run after this batch.
351
+ will_run_val = False
352
+ if getattr (trainer , "enable_validation" , False ):
353
+ num_val_batches = (
354
+ sum (trainer .num_val_batches )
355
+ if isinstance (trainer .num_val_batches , list )
356
+ else trainer .num_val_batches
357
+ )
358
+ if num_val_batches and num_val_batches > 0 :
359
+ cve = trainer .check_val_every_n_epoch
360
+ if cve is None or ((trainer .current_epoch + 1 ) % cve == 0 ):
361
+ will_run_val = True
362
+
363
+ if will_run_val :
364
+ self ._defer_save_until_validation = True
365
+ return
366
+
367
+ # Only proceed to save if not skipping due to trainer/callback state
368
+ if skip_due_to_state :
369
+ return
370
+
329
371
self ._save_topk_checkpoint (trainer , monitor_candidates )
330
372
self ._save_last_checkpoint (trainer , monitor_candidates )
331
373
@@ -343,6 +385,14 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
343
385
"""Save a checkpoint at the end of the validation stage."""
344
386
if not self ._should_skip_saving_checkpoint (trainer ) and not self ._should_save_on_train_epoch_end (trainer ):
345
387
monitor_candidates = self ._monitor_candidates (trainer )
388
+ # If a step/time-triggered save was deferred due to a missing monitored metric,
389
+ # perform the save now that validation metrics are available.
390
+ if self ._defer_save_until_validation :
391
+ self ._save_topk_checkpoint (trainer , monitor_candidates )
392
+ self ._save_last_checkpoint (trainer , monitor_candidates )
393
+ self ._defer_save_until_validation = False
394
+ return
395
+
346
396
if self ._every_n_epochs >= 1 and (trainer .current_epoch + 1 ) % self ._every_n_epochs == 0 :
347
397
self ._save_topk_checkpoint (trainer , monitor_candidates )
348
398
self ._save_last_checkpoint (trainer , monitor_candidates )
0 commit comments