
Commit b1cc925

littlebullGit and Borda authored
fix(callbacks): Defer step/time-triggered ModelCheckpoint saves until validation metrics are available (#21106)
* fix(callbacks): defer step/time-triggered ModelCheckpoint saves until validation metrics are available

  Root cause:
  - With `every_n_train_steps` (or `train_time_interval`), checkpoints could save at train batch end before validation ran, so the monitored val metric was missing/stale and `best_model_score` was incorrect. (Refs #20919)

  Fix:
  - In [src/lightning/pytorch/callbacks/model_checkpoint.py:ModelCheckpoint.on_train_batch_end]:
    - Defer saves when the monitored key is missing from [trainer.callback_metrics]
    - If on the last train batch and not saving at train-epoch-end, defer only when validation will run next:
      - `trainer.enable_validation` is True
      - `trainer.num_val_batches` > 0
      - `trainer.check_val_every_n_epoch` schedule matches the upcoming epoch
  - Perform deferred saves in [on_validation_end], ensuring fresh validation metrics are used.
  - Allow zero `timedelta` for `train_time_interval` and broadcast the time-trigger decision across ranks.
  - Do not defer when monitoring a train metric or when no validation is scheduled.

  Tests:
  - Repro (previously failing, now passing):
    - [tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py]
  - Additional validations:
    - [tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py]
    - [tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py]

  Outcome:
  - `best_model_score` matches the validation metric after the epoch.
  - Step/time-interval checkpointing behaves correctly without premature or skipped saves.

* test: disable logger in model checkpoint tests to avoid side effects

* chlog

---------

Co-authored-by: Jirka B <[email protected]>
1 parent: d85c474 · commit: b1cc925
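
For context, a minimal sketch (not part of the commit) of the configuration this fix targets: a step-triggered ModelCheckpoint monitoring a metric that is only logged during validation. The module name `MyModel` and the metric name `val_loss` are illustrative placeholders.

# Sketch of the affected setup; `MyModel` and "val_loss" are illustrative,
# everything else uses the public Lightning API referenced by this commit.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint(
    monitor="val_loss",            # logged only by the validation loop
    mode="min",
    every_n_train_steps=100,       # step-triggered saves
    save_on_train_epoch_end=False,
)
trainer = Trainer(max_epochs=3, callbacks=[ckpt])
# trainer.fit(MyModel())
# Before this fix, a save triggered at train batch end could run before validation,
# so "val_loss" was missing or stale and ckpt.best_model_score could be wrong.
# After the fix, such saves are deferred to on_validation_end.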

File tree

- src/lightning/pytorch/CHANGELOG.md
- src/lightning/pytorch/callbacks/model_checkpoint.py
- tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py
- tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py
- tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py

5 files changed: +543 -4 lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
@@ -28,7 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
-
+- Fixed callbacks by defer step/time-triggered `ModelCheckpoint` saves until validation metrics are available ([#21106](https://github.com/Lightning-AI/pytorch-lightning/pull/21106))
+
 
 
 
---

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 53 additions & 3 deletions
@@ -260,6 +260,9 @@ def __init__(
         self.best_model_path = ""
         self.last_model_path = ""
         self._last_checkpoint_saved = ""
+        # When using step/time-based checkpointing with a validation-only monitored metric,
+        # defer the save until validation has produced the metric
+        self._defer_save_until_validation: bool = False
 
         self.kth_value: Tensor
         self.dirpath: Optional[_PATH]
@@ -306,14 +309,17 @@ def on_train_batch_end(
         batch_idx: int,
     ) -> None:
         """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
-        if self._should_skip_saving_checkpoint(trainer):
-            return
+        # Do not return early here because we may need to set deferral flags even
+        # if a save already happened at this global step. We'll enforce the skip
+        # just before actually saving below.
+        skip_due_to_state = self._should_skip_saving_checkpoint(trainer)
         skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)
 
         train_time_interval = self._train_time_interval
         skip_time = True
         now = time.monotonic()
-        if train_time_interval:
+        # Important: allow zero timedelta as a valid interval
+        if train_time_interval is not None:
             prev_time_check = self._last_time_checked
             skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
             # in case we have time differences across ranks
@@ -326,6 +332,42 @@ def on_train_batch_end(
             self._last_time_checked = now
 
         monitor_candidates = self._monitor_candidates(trainer)
+        # If monitoring a metric that is not yet available (e.g., validation-only),
+        # defer saving until validation end so the metric is present.
+        if self.monitor is not None and self.monitor not in monitor_candidates:
+            # Defer both top-k and last to avoid blocking with `_last_global_step_saved`
+            self._defer_save_until_validation = True
+            return
+
+        # Even if the monitored key exists, it could be stale from a previous validation.
+        # If validation is scheduled to run right after this batch (e.g., last batch of epoch)
+        # and we are not saving at train epoch end, defer to `on_validation_end` to use fresh metrics.
+        if (
+            self.monitor is not None
+            and not self._should_save_on_train_epoch_end(trainer)
+            and getattr(trainer.fit_loop.epoch_loop.batch_progress, "is_last_batch", False)
+        ):
+            # Only defer if a validation loop is expected to run after this batch.
+            will_run_val = False
+            if getattr(trainer, "enable_validation", False):
+                num_val_batches = (
+                    sum(trainer.num_val_batches)
+                    if isinstance(trainer.num_val_batches, list)
+                    else trainer.num_val_batches
+                )
+                if num_val_batches and num_val_batches > 0:
+                    cve = trainer.check_val_every_n_epoch
+                    if cve is None or ((trainer.current_epoch + 1) % cve == 0):
+                        will_run_val = True
+
+            if will_run_val:
+                self._defer_save_until_validation = True
+                return
+
+        # Only proceed to save if not skipping due to trainer/callback state
+        if skip_due_to_state:
+            return
+
         self._save_topk_checkpoint(trainer, monitor_candidates)
         self._save_last_checkpoint(trainer, monitor_candidates)

@@ -343,6 +385,14 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
         """Save a checkpoint at the end of the validation stage."""
         if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer):
             monitor_candidates = self._monitor_candidates(trainer)
+            # If a step/time-triggered save was deferred due to a missing monitored metric,
+            # perform the save now that validation metrics are available.
+            if self._defer_save_until_validation:
+                self._save_topk_checkpoint(trainer, monitor_candidates)
+                self._save_last_checkpoint(trainer, monitor_candidates)
+                self._defer_save_until_validation = False
+                return
+
             if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
                 self._save_topk_checkpoint(trainer, monitor_candidates)
                 self._save_last_checkpoint(trainer, monitor_candidates)
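
As a reference for the time-trigger change above, a minimal sketch of a configuration that exercises it: `timedelta(seconds=0)` is now accepted as a valid interval because the code checks `is not None` rather than truthiness, and saves for a validation-only metric are still deferred to `on_validation_end`. The metric name `auroc` is illustrative, mirroring the new tests below.

# Sketch of a configuration exercising the zero-timedelta time trigger; "auroc" is
# a validation-only metric as in the new tests (illustrative, not prescriptive).
from datetime import timedelta
from lightning.pytorch.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint(
    monitor="auroc",
    mode="max",
    every_n_train_steps=0,                     # disable step-based triggering
    train_time_interval=timedelta(seconds=0),  # now treated as a valid, always-due interval
    save_on_train_epoch_end=False,
)
# The time-trigger decision is broadcast across ranks; when "auroc" is not yet in
# trainer.callback_metrics, the save is deferred until on_validation_end.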
Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
import math
from datetime import timedelta

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint


class TinyDataset(Dataset):
    def __init__(self, n: int = 4):
        self.x = torch.arange(n, dtype=torch.float32).view(-1, 1)
        self.y = self.x.clone()

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


class TrainMetricModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(1, 1)
        self._counter = 0.0

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.layer(x)
        loss = F.mse_loss(y_hat, y)
        # strictly increasing train metric per step
        self._counter += 1.0
        self.log("train_score", torch.tensor(self._counter), on_step=True, on_epoch=False, prog_bar=False, logger=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


def _make_loaders(n=4):
    ds = TinyDataset(n=n)
    train_loader = DataLoader(ds, batch_size=2, shuffle=False)
    val_loader = DataLoader(ds, batch_size=2, shuffle=False)
    return train_loader, val_loader


def test_model_checkpoint_every_n_train_steps_with_train_metric_saves_at_step(tmp_path):
    """When monitoring a train-step metric, step-interval checkpointing should save at the step boundary (no deferral)
    and best_model_score should match the last train metric value."""
    seed_everything(123)

    train_loader, val_loader = _make_loaders(n=4)
    model = TrainMetricModule()

    ckpt = ModelCheckpoint(
        dirpath=tmp_path,
        monitor="train_score",
        mode="max",
        save_top_k=1,
        every_n_train_steps=1,
        train_time_interval=None,
        every_n_epochs=0,
        save_on_train_epoch_end=False,
        save_weights_only=True,
    )

    # 2 batches/epoch, run 2 epochs to have multiple step saves
    trainer = Trainer(
        max_epochs=2,
        accelerator="cpu",
        devices=1,
        callbacks=[ckpt],
        num_sanity_val_steps=0,
        log_every_n_steps=1,
        limit_train_batches=2,
        limit_val_batches=0,  # no validation needed for this test
        enable_checkpointing=True,
        enable_model_summary=False,
        logger=False,
    )

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    assert ckpt.best_model_score is not None
    # 2 epochs * 2 steps/epoch = 4 steps total; metric increments by 1 each step
    expected = 4.0
    actual = float(ckpt.best_model_score)
    assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6)


@pytest.mark.parametrize("val_scores", [[0.2, 0.4, 0.9]])
def test_model_checkpoint_time_interval_with_val_metric_defers_until_validation(tmp_path, val_scores):
    """With time-interval-based checkpointing, and a validation-only metric, ensure we don't save using stale metrics
    at step boundaries; saving should occur at validation end."""
    seed_everything(123)

    train_loader, val_loader = _make_loaders(n=4)

    model = ValMetricModule(val_scores=val_scores)

    ckpt = ModelCheckpoint(
        dirpath=tmp_path,
        monitor="auroc",
        mode="max",
        save_top_k=1,
        every_n_train_steps=0,  # disable step-based
        train_time_interval=timedelta(seconds=0),  # trigger as often as possible
        every_n_epochs=0,
        save_on_train_epoch_end=False,
        save_weights_only=True,
    )

    trainer = Trainer(
        max_epochs=len(val_scores),
        accelerator="cpu",
        devices=1,
        callbacks=[ckpt],
        num_sanity_val_steps=0,
        log_every_n_steps=1,
        limit_train_batches=2,
        limit_val_batches=1,
        enable_checkpointing=True,
        enable_model_summary=False,
        logger=False,
    )

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    assert ckpt.best_model_score is not None
    expected = max(val_scores)
    actual = float(ckpt.best_model_score)
    assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6)


class ValMetricModule(LightningModule):
    def __init__(self, val_scores: list[float]):
        super().__init__()
        self.layer = nn.Linear(1, 1)
        self._val_scores = [float(s) for s in val_scores]

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.layer(x)
        loss = F.mse_loss(y_hat, y)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        pass

    def on_validation_epoch_end(self):
        score = self._val_scores[self.current_epoch]
        self.log("auroc", torch.tensor(score, dtype=torch.float32), prog_bar=False, logger=True)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


@pytest.mark.parametrize("val_scores", [[0.1, 0.5, 1.0, 3.0]])
def test_model_checkpoint_defer_until_next_validation_when_val_every_2_epochs(tmp_path, val_scores):
    """With validation running every 2 epochs, step-triggered saves at the end of non-validation epochs should be
    deferred and then performed at the next validation end when the metric is available."""
    seed_everything(123)

    train_loader, val_loader = _make_loaders(n=4)

    model = ValMetricModule(val_scores=val_scores)

    ckpt = ModelCheckpoint(
        dirpath=tmp_path,
        monitor="auroc",
        mode="max",
        save_top_k=1,
        every_n_train_steps=2,  # end of each epoch
        train_time_interval=None,
        every_n_epochs=0,
        save_on_train_epoch_end=False,
        save_weights_only=True,
    )

    trainer = Trainer(
        max_epochs=len(val_scores),
        accelerator="cpu",
        devices=1,
        callbacks=[ckpt],
        num_sanity_val_steps=0,
        log_every_n_steps=1,
        limit_train_batches=2,
        limit_val_batches=1,
        enable_checkpointing=True,
        enable_model_summary=False,
        logger=False,
        check_val_every_n_epoch=2,  # only validate every 2 epochs
    )

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    assert ckpt.best_model_score is not None
    expected = max(val_scores)  # last/maximum value occurs at final validation epoch
    actual = float(ckpt.best_model_score)
    assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6)
