Merged
3 changes: 2 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -28,7 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed `ModelCheckpoint` to defer step/time-triggered saves until validation metrics are available ([#21106](https://github.com/Lightning-AI/pytorch-lightning/pull/21106))



---
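
For context, a minimal configuration that exercises the fixed path is a `ModelCheckpoint` monitoring a metric that is logged only during validation while saving on a step interval. Before this change, the step-triggered save could fire before the metric existed. A hedged sketch (the metric name `val_auroc` is illustrative):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# "val_auroc" is logged only in validation, so it is absent from the monitor
# candidates at train-step boundaries; step-triggered saves are now deferred
# until validation produces it.
ckpt = ModelCheckpoint(
    monitor="val_auroc",
    mode="max",
    save_top_k=1,
    every_n_train_steps=100,
    save_on_train_epoch_end=False,
)
trainer = Trainer(max_epochs=3, callbacks=[ckpt])
```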
56 changes: 53 additions & 3 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -260,6 +260,9 @@ def __init__(
self.best_model_path = ""
self.last_model_path = ""
self._last_checkpoint_saved = ""
# When using step/time-based checkpointing with a validation-only monitored metric,
# defer the save until validation has produced the metric
self._defer_save_until_validation: bool = False

self.kth_value: Tensor
self.dirpath: Optional[_PATH]
@@ -306,14 +309,17 @@ def on_train_batch_end(
batch_idx: int,
) -> None:
"""Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
# Do not return early here because we may need to set deferral flags even
# if a save already happened at this global step. We'll enforce the skip
# just before actually saving below.
skip_due_to_state = self._should_skip_saving_checkpoint(trainer)
skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)

train_time_interval = self._train_time_interval
skip_time = True
now = time.monotonic()
# Important: allow zero timedelta as a valid interval
if train_time_interval is not None:
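# (`timedelta(0)` is falsy, so the previous `if train_time_interval:` check silently disabled zero-second intervals)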
prev_time_check = self._last_time_checked
skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
# in case we have time differences across ranks
@@ -326,6 +332,42 @@
self._last_time_checked = now

monitor_candidates = self._monitor_candidates(trainer)
# If monitoring a metric that is not yet available (e.g., validation-only),
# defer saving until validation end so the metric is present.
if self.monitor is not None and self.monitor not in monitor_candidates:
# Defer both top-k and last to avoid blocking with `_last_global_step_saved`
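# (saving only `last` here would set `_last_global_step_saved` to this global step, and
# `_should_skip_saving_checkpoint` would then skip the deferred top-k save at validation end)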
self._defer_save_until_validation = True
return

# Even if the monitored key exists, it could be stale from a previous validation.
# If validation is scheduled to run right after this batch (e.g., last batch of epoch)
# and we are not saving at train epoch end, defer to `on_validation_end` to use fresh metrics.
if (
self.monitor is not None
and not self._should_save_on_train_epoch_end(trainer)
and getattr(trainer.fit_loop.epoch_loop.batch_progress, "is_last_batch", False)
):
# Only defer if a validation loop is expected to run after this batch.
will_run_val = False
if getattr(trainer, "enable_validation", False):
num_val_batches = (
sum(trainer.num_val_batches)
if isinstance(trainer.num_val_batches, list)
else trainer.num_val_batches
)
if num_val_batches and num_val_batches > 0:
cve = trainer.check_val_every_n_epoch
if cve is None or ((trainer.current_epoch + 1) % cve == 0):
will_run_val = True

if will_run_val:
self._defer_save_until_validation = True
return

# Only proceed to save if not skipping due to trainer/callback state
if skip_due_to_state:
return

self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)
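
The decision above can be restated as a standalone sketch (simplified; `should_defer_save` is not an actual helper in the callback):

```python
def should_defer_save(monitor, candidates, saves_on_epoch_end, is_last_batch, will_run_val):
    """Simplified restatement of the deferral logic in `on_train_batch_end`."""
    if monitor is not None and monitor not in candidates:
        # Metric not produced yet (e.g. validation-only): wait for validation.
        return True
    if monitor is not None and not saves_on_epoch_end and is_last_batch and will_run_val:
        # Metric exists but may be stale; fresh values arrive right after this batch.
        return True
    return False
```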

@@ -343,6 +385,14 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
"""Save a checkpoint at the end of the validation stage."""
if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer):
monitor_candidates = self._monitor_candidates(trainer)
# If a step/time-triggered save was deferred due to a missing monitored metric,
# perform the save now that validation metrics are available.
if self._defer_save_until_validation:
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)
self._defer_save_until_validation = False
return

if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)
@@ -0,0 +1,208 @@
import math
from datetime import timedelta

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint


class TinyDataset(Dataset):
def __init__(self, n: int = 4):
self.x = torch.arange(n, dtype=torch.float32).view(-1, 1)
self.y = self.x.clone()

def __len__(self):
return len(self.x)

def __getitem__(self, idx):
return self.x[idx], self.y[idx]


class TrainMetricModule(LightningModule):
def __init__(self):
super().__init__()
self.layer = nn.Linear(1, 1)
self._counter = 0.0

def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.layer(x)
loss = F.mse_loss(y_hat, y)
# strictly increasing train metric per step
self._counter += 1.0
self.log("train_score", torch.tensor(self._counter), on_step=True, on_epoch=False, prog_bar=False, logger=True)
return {"loss": loss}

def validation_step(self, batch, batch_idx):
pass

def configure_optimizers(self):
return torch.optim.SGD(self.parameters(), lr=0.01)


def _make_loaders(n=4):
ds = TinyDataset(n=n)
train_loader = DataLoader(ds, batch_size=2, shuffle=False)
val_loader = DataLoader(ds, batch_size=2, shuffle=False)
return train_loader, val_loader


def test_model_checkpoint_every_n_train_steps_with_train_metric_saves_at_step(tmp_path):
"""When monitoring a train-step metric, step-interval checkpointing should save at the step boundary (no deferral)
and best_model_score should match the last train metric value."""
seed_everything(123)

train_loader, val_loader = _make_loaders(n=4)
model = TrainMetricModule()

ckpt = ModelCheckpoint(
dirpath=tmp_path,
monitor="train_score",
mode="max",
save_top_k=1,
every_n_train_steps=1,
train_time_interval=None,
every_n_epochs=0,
save_on_train_epoch_end=False,
save_weights_only=True,
)

# 2 batches/epoch, run 2 epochs to have multiple step saves
trainer = Trainer(
max_epochs=2,
accelerator="cpu",
devices=1,
callbacks=[ckpt],
num_sanity_val_steps=0,
log_every_n_steps=1,
limit_train_batches=2,
limit_val_batches=0, # no validation needed for this test
enable_checkpointing=True,
enable_model_summary=False,
logger=False,
)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

assert ckpt.best_model_score is not None
# 2 epochs * 2 steps/epoch = 4 steps total; metric increments by 1 each step
expected = 4.0
actual = float(ckpt.best_model_score)
assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6)


@pytest.mark.parametrize("val_scores", [[0.2, 0.4, 0.9]])
def test_model_checkpoint_time_interval_with_val_metric_defers_until_validation(tmp_path, val_scores):
"""With time-interval-based checkpointing, and a validation-only metric, ensure we don't save using stale metrics
at step boundaries; saving should occur at validation end."""
seed_everything(123)

train_loader, val_loader = _make_loaders(n=4)

model = ValMetricModule(val_scores=val_scores)

ckpt = ModelCheckpoint(
dirpath=tmp_path,
monitor="auroc",
mode="max",
save_top_k=1,
every_n_train_steps=0, # disable step-based
train_time_interval=timedelta(seconds=0), # trigger as often as possible
every_n_epochs=0,
save_on_train_epoch_end=False,
save_weights_only=True,
)

trainer = Trainer(
max_epochs=len(val_scores),
accelerator="cpu",
devices=1,
callbacks=[ckpt],
num_sanity_val_steps=0,
log_every_n_steps=1,
limit_train_batches=2,
limit_val_batches=1,
enable_checkpointing=True,
enable_model_summary=False,
logger=False,
)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

assert ckpt.best_model_score is not None
expected = max(val_scores)
actual = float(ckpt.best_model_score)
assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6)


class ValMetricModule(LightningModule):
def __init__(self, val_scores: list[float]):
super().__init__()
self.layer = nn.Linear(1, 1)
self._val_scores = [float(s) for s in val_scores]

def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.layer(x)
loss = F.mse_loss(y_hat, y)
return {"loss": loss}

def validation_step(self, batch, batch_idx):
pass

def on_validation_epoch_end(self):
score = self._val_scores[self.current_epoch]
self.log("auroc", torch.tensor(score, dtype=torch.float32), prog_bar=False, logger=True)

def configure_optimizers(self):
return torch.optim.SGD(self.parameters(), lr=0.01)


@pytest.mark.parametrize("val_scores", [[0.1, 0.5, 1.0, 3.0]])
def test_model_checkpoint_defer_until_next_validation_when_val_every_2_epochs(tmp_path, val_scores):
"""With validation running every 2 epochs, step-triggered saves at the end of non-validation epochs should be
deferred and then performed at the next validation end when the metric is available."""
seed_everything(123)

train_loader, val_loader = _make_loaders(n=4)

model = ValMetricModule(val_scores=val_scores)

ckpt = ModelCheckpoint(
dirpath=tmp_path,
monitor="auroc",
mode="max",
save_top_k=1,
every_n_train_steps=2, # end of each epoch
train_time_interval=None,
every_n_epochs=0,
save_on_train_epoch_end=False,
save_weights_only=True,
)

trainer = Trainer(
max_epochs=len(val_scores),
accelerator="cpu",
devices=1,
callbacks=[ckpt],
num_sanity_val_steps=0,
log_every_n_steps=1,
limit_train_batches=2,
limit_val_batches=1,
enable_checkpointing=True,
enable_model_summary=False,
logger=False,
check_val_every_n_epoch=2, # only validate every 2 epochs
)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

assert ckpt.best_model_score is not None
expected = max(val_scores) # last/maximum value occurs at final validation epoch
actual = float(ckpt.best_model_score)
assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6)
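
As a sanity check on the last test's expectation: with `check_val_every_n_epoch=2` and four epochs, validation runs only where `(current_epoch + 1) % 2 == 0`, so the logged scores are `val_scores[1]` and `val_scores[3]`:

```python
check_val_every_n_epoch = 2
validating_epochs = [e for e in range(4) if (e + 1) % check_val_every_n_epoch == 0]
assert validating_epochs == [1, 3]
# val_scores = [0.1, 0.5, 1.0, 3.0] -> observed scores 0.5 and 3.0; best is 3.0 == max(val_scores)
```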