diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index fe9173d008230..176e34273d776 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -28,7 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed `ModelCheckpoint` to defer step/time-triggered saves until validation metrics are available ([#21106](https://github.com/Lightning-AI/pytorch-lightning/pull/21106))
+
 
 ---
 
diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 68fed2ff82d31..452e8bdecbba3 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -260,6 +260,9 @@ def __init__(
         self.best_model_path = ""
         self.last_model_path = ""
         self._last_checkpoint_saved = ""
+        # When using step/time-based checkpointing with a validation-only monitored metric,
+        # defer the save until validation has produced the metric
+        self._defer_save_until_validation: bool = False
 
         self.kth_value: Tensor
         self.dirpath: Optional[_PATH]
@@ -306,14 +309,17 @@ def on_train_batch_end(
         batch_idx: int,
     ) -> None:
         """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
-        if self._should_skip_saving_checkpoint(trainer):
-            return
+        # Do not return early here because we may need to set deferral flags even
+        # if a save already happened at this global step. We'll enforce the skip
+        # just before actually saving below.
+        skip_due_to_state = self._should_skip_saving_checkpoint(trainer)
         skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)
 
         train_time_interval = self._train_time_interval
         skip_time = True
         now = time.monotonic()
-        if train_time_interval:
+        # Important: allow zero timedelta as a valid interval
+        if train_time_interval is not None:
             prev_time_check = self._last_time_checked
             skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
             # in case we have time differences across ranks
@@ -326,6 +332,42 @@ def on_train_batch_end(
             self._last_time_checked = now
 
         monitor_candidates = self._monitor_candidates(trainer)
+        # If monitoring a metric that is not yet available (e.g., validation-only),
+        # defer saving until validation end so the metric is present.
+        if self.monitor is not None and self.monitor not in monitor_candidates:
+            # Defer both top-k and last to avoid blocking with `_last_global_step_saved`
+            self._defer_save_until_validation = True
+            return
+
+        # Even if the monitored key exists, it could be stale from a previous validation.
+        # If validation is scheduled to run right after this batch (e.g., last batch of epoch)
+        # and we are not saving at train epoch end, defer to `on_validation_end` to use fresh metrics.
+        if (
+            self.monitor is not None
+            and not self._should_save_on_train_epoch_end(trainer)
+            and getattr(trainer.fit_loop.epoch_loop.batch_progress, "is_last_batch", False)
+        ):
+            # Only defer if a validation loop is expected to run after this batch.
+            will_run_val = False
+            if getattr(trainer, "enable_validation", False):
+                num_val_batches = (
+                    sum(trainer.num_val_batches)
+                    if isinstance(trainer.num_val_batches, list)
+                    else trainer.num_val_batches
+                )
+                if num_val_batches and num_val_batches > 0:
+                    cve = trainer.check_val_every_n_epoch
+                    if cve is None or ((trainer.current_epoch + 1) % cve == 0):
+                        will_run_val = True
+
+            if will_run_val:
+                self._defer_save_until_validation = True
+                return
+
+        # Only proceed to save if not skipping due to trainer/callback state
+        if skip_due_to_state:
+            return
+
         self._save_topk_checkpoint(trainer, monitor_candidates)
         self._save_last_checkpoint(trainer, monitor_candidates)
 
@@ -343,6 +385,14 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
         """Save a checkpoint at the end of the validation stage."""
         if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer):
             monitor_candidates = self._monitor_candidates(trainer)
+            # If a step/time-triggered save was deferred due to a missing monitored metric,
+            # perform the save now that validation metrics are available.
+            if self._defer_save_until_validation:
+                self._save_topk_checkpoint(trainer, monitor_candidates)
+                self._save_last_checkpoint(trainer, monitor_candidates)
+                self._defer_save_until_validation = False
+                return
+
             if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
                 self._save_topk_checkpoint(trainer, monitor_candidates)
             self._save_last_checkpoint(trainer, monitor_candidates)
diff --git a/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py b/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py
new file mode 100644
index 0000000000000..fba9e865debfd
--- /dev/null
+++ b/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py
@@ -0,0 +1,208 @@
+import math
+from datetime import timedelta
+
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset
+
+from lightning.pytorch import LightningModule, Trainer, seed_everything
+from lightning.pytorch.callbacks import ModelCheckpoint
+
+
+class TinyDataset(Dataset):
+    def __init__(self, n: int = 4):
+        self.x = torch.arange(n, dtype=torch.float32).view(-1, 1)
+        self.y = self.x.clone()
+
+    def __len__(self):
+        return len(self.x)
+
+    def __getitem__(self, idx):
+        return self.x[idx], self.y[idx]
+
+
+class TrainMetricModule(LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.layer = nn.Linear(1, 1)
+        self._counter = 0.0
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.layer(x)
+        loss = F.mse_loss(y_hat, y)
+        # strictly increasing train metric per step
+        self._counter += 1.0
+        self.log("train_score", torch.tensor(self._counter), on_step=True, on_epoch=False, prog_bar=False, logger=True)
+        return {"loss": loss}
+
+    def validation_step(self, batch, batch_idx):
+        pass
+
+    def configure_optimizers(self):
+        return torch.optim.SGD(self.parameters(), lr=0.01)
+
+
+def _make_loaders(n=4):
+    ds = TinyDataset(n=n)
+    train_loader = DataLoader(ds, batch_size=2, shuffle=False)
+    val_loader = DataLoader(ds, batch_size=2, shuffle=False)
+    return train_loader, val_loader
+
+
+def test_model_checkpoint_every_n_train_steps_with_train_metric_saves_at_step(tmp_path):
+    """When monitoring a train-step metric, step-interval checkpointing should save at the step boundary (no deferral)
+    and best_model_score should match the last train metric value."""
+    seed_everything(123)
+
+    train_loader, val_loader = _make_loaders(n=4)
+    model = TrainMetricModule()
+
+    ckpt = ModelCheckpoint(
+        dirpath=tmp_path,
+        monitor="train_score",
+        mode="max",
+        save_top_k=1,
+        every_n_train_steps=1,
+        train_time_interval=None,
+        every_n_epochs=0,
+        save_on_train_epoch_end=False,
+        save_weights_only=True,
+    )
+
+    # 2 batches/epoch, run 2 epochs to have multiple step saves
+    trainer = Trainer(
+        max_epochs=2,
+        accelerator="cpu",
+        devices=1,
+        callbacks=[ckpt],
+        num_sanity_val_steps=0,
+        log_every_n_steps=1,
+        limit_train_batches=2,
+        limit_val_batches=0,  # no validation needed for this test
+        enable_checkpointing=True,
+        enable_model_summary=False,
+        logger=False,
+    )
+
+    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
+
+    assert ckpt.best_model_score is not None
+    # 2 epochs * 2 steps/epoch = 4 steps total; metric increments by 1 each step
+    expected = 4.0
+    actual = float(ckpt.best_model_score)
+    assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6)
+
+
+@pytest.mark.parametrize("val_scores", [[0.2, 0.4, 0.9]])
+def test_model_checkpoint_time_interval_with_val_metric_defers_until_validation(tmp_path, val_scores):
+    """With time-interval-based checkpointing and a validation-only metric, ensure we don't save using stale metrics
+    at step boundaries; saving should occur at validation end."""
+    seed_everything(123)
+
+    train_loader, val_loader = _make_loaders(n=4)
+
+    model = ValMetricModule(val_scores=val_scores)
+
+    ckpt = ModelCheckpoint(
+        dirpath=tmp_path,
+        monitor="auroc",
+        mode="max",
+        save_top_k=1,
+        every_n_train_steps=0,  # disable step-based
+        train_time_interval=timedelta(seconds=0),  # trigger as often as possible
+        every_n_epochs=0,
+        save_on_train_epoch_end=False,
+        save_weights_only=True,
+    )
+
+    trainer = Trainer(
+        max_epochs=len(val_scores),
+        accelerator="cpu",
+        devices=1,
+        callbacks=[ckpt],
+        num_sanity_val_steps=0,
+        log_every_n_steps=1,
+        limit_train_batches=2,
+        limit_val_batches=1,
+        enable_checkpointing=True,
+        enable_model_summary=False,
+        logger=False,
+    )
+
+    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
+
+    assert ckpt.best_model_score is not None
+    expected = max(val_scores)
+    actual = float(ckpt.best_model_score)
+    assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6)
+
+
+class ValMetricModule(LightningModule):
+    def __init__(self, val_scores: list[float]):
+        super().__init__()
+        self.layer = nn.Linear(1, 1)
+        self._val_scores = [float(s) for s in val_scores]
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.layer(x)
+        loss = F.mse_loss(y_hat, y)
+        return {"loss": loss}
+
+    def validation_step(self, batch, batch_idx):
+        pass
+
+    def on_validation_epoch_end(self):
+        score = self._val_scores[self.current_epoch]
+        self.log("auroc", torch.tensor(score, dtype=torch.float32), prog_bar=False, logger=True)
+
+    def configure_optimizers(self):
+        return torch.optim.SGD(self.parameters(), lr=0.01)
+
+
+@pytest.mark.parametrize("val_scores", [[0.1, 0.5, 1.0, 3.0]])
+def test_model_checkpoint_defer_until_next_validation_when_val_every_2_epochs(tmp_path, val_scores):
+    """With validation running every 2 epochs, step-triggered saves at the end of non-validation epochs should be
+    deferred and then performed at the next validation end when the metric is available."""
+    seed_everything(123)
+
+    train_loader, val_loader = _make_loaders(n=4)
+
+    model = ValMetricModule(val_scores=val_scores)
+
+    ckpt = ModelCheckpoint(
+        dirpath=tmp_path,
+        monitor="auroc",
+        mode="max",
+        save_top_k=1,
+        every_n_train_steps=2,  # end of each epoch
+        train_time_interval=None,
+        every_n_epochs=0,
+        save_on_train_epoch_end=False,
+        save_weights_only=True,
+    )
+
+    trainer = Trainer(
+        max_epochs=len(val_scores),
+        accelerator="cpu",
+        devices=1,
+        callbacks=[ckpt],
+        num_sanity_val_steps=0,
+        log_every_n_steps=1,
+        limit_train_batches=2,
+        limit_val_batches=1,
+        enable_checkpointing=True,
+        enable_model_summary=False,
+        logger=False,
+        check_val_every_n_epoch=2,  # only validate every 2 epochs
+    )
+
+    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
+
+    assert ckpt.best_model_score is not None
+    expected = max(val_scores)  # last/maximum value occurs at final validation epoch
+    actual = float(ckpt.best_model_score)
+    assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6)
diff --git a/tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py b/tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py
new file mode 100644
index 0000000000000..a265f8bc5f194
--- /dev/null
+++ b/tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py
@@ -0,0 +1,174 @@
+import math
+from datetime import timedelta
+
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset
+
+from lightning.pytorch import LightningModule, Trainer, seed_everything
+from lightning.pytorch.callbacks import ModelCheckpoint
+
+
+class TinyDataset(Dataset):
+    def __init__(self, n: int = 8):
+        self.x = torch.arange(n, dtype=torch.float32).view(-1, 1)
+        self.y = self.x.clone()
+
+    def __len__(self):
+        return len(self.x)
+
+    def __getitem__(self, idx):
+        return self.x[idx], self.y[idx]
+
+
+def _make_loaders(n=8, batch_size=2):
+    ds = TinyDataset(n=n)
+    train_loader = DataLoader(ds, batch_size=batch_size, shuffle=False)
+    val_loader = DataLoader(ds, batch_size=batch_size, shuffle=False)
+    return train_loader, val_loader
+
+
+class MultiValPerEpochModule(LightningModule):
+    """Logs a validation metric on every validation run, even if validation is run multiple times per epoch."""
+
+    def __init__(self, val_scores: list[float]):
+        super().__init__()
+        self.layer = nn.Linear(1, 1)
+        self._val_scores = [float(s) for s in val_scores]
+        self._val_call_idx = 0
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.layer(x)
+        loss = F.mse_loss(y_hat, y)
+        return {"loss": loss}
+
+    def validation_step(self, batch, batch_idx):
+        pass
+
+    def on_validation_epoch_end(self):
+        score = self._val_scores[self._val_call_idx]
+        self._val_call_idx += 1
+        self.log("auroc", torch.tensor(score, dtype=torch.float32), prog_bar=False, logger=True)
+
+    def configure_optimizers(self):
+        return torch.optim.SGD(self.parameters(), lr=0.01)
+
+
+class ValOnceEveryTwoEpochsModule(LightningModule):
+    """Logs a validation metric only when validation runs (e.g., every 2 epochs), indexed by current_epoch."""
+
+    def __init__(self, val_scores: list[float]):
+        super().__init__()
+        self.layer = nn.Linear(1, 1)
+        self._val_scores = [float(s) for s in val_scores]
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.layer(x)
+        loss = F.mse_loss(y_hat, y)
+        return {"loss": loss}
+
+    def validation_step(self, batch, batch_idx):
+        pass
+
+    def on_validation_epoch_end(self):
+        # current_epoch indexes into provided scores; only called when validation runs
+        score = self._val_scores[self.current_epoch]
+        self.log("auroc", torch.tensor(score, dtype=torch.float32), prog_bar=False, logger=True)
+
+    def configure_optimizers(self):
+        return torch.optim.SGD(self.parameters(), lr=0.01)
+
+
+@pytest.mark.parametrize("val_scores", [[0.1, 0.9]])
+def test_checkpoint_defers_with_mid_epoch_validation(tmp_path, val_scores):
+    """With val_check_interval=0.5 (validation mid-epoch and at epoch end) and step-based checkpointing, saves must be
+    deferred until each validation end so monitored validation metrics are fresh."""
+    seed_everything(123)
+
+    # 4 train batches per epoch (batch_size=2 over n=8), so two validations: after 2 batches and after 4 batches
+    train_loader, val_loader = _make_loaders(n=8, batch_size=2)
+
+    model = MultiValPerEpochModule(val_scores=val_scores)
+
+    ckpt = ModelCheckpoint(
+        dirpath=tmp_path,
+        monitor="auroc",
+        mode="max",
+        save_top_k=1,
+        every_n_train_steps=1,  # would trigger every step, but must defer to validation
+        train_time_interval=None,
+        every_n_epochs=0,
+        save_on_train_epoch_end=False,
+        save_weights_only=True,
+    )
+
+    trainer = Trainer(
+        max_epochs=1,
+        accelerator="cpu",
+        devices=1,
+        callbacks=[ckpt],
+        num_sanity_val_steps=0,
+        log_every_n_steps=1,
+        limit_train_batches=4,  # ensure exactly 4 steps => two validations at 0.5 and 1.0
+        limit_val_batches=1,
+        enable_checkpointing=True,
+        enable_model_summary=False,
+        logger=False,
+        val_check_interval=0.5,
+    )
+
+    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
+
+    assert ckpt.best_model_score is not None
+    expected = max(val_scores)
+    actual = float(ckpt.best_model_score)
+    assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6)
+
+
+@pytest.mark.parametrize("val_scores", [[0.2, 0.6]])
+def test_time_interval_defers_across_epoch_until_first_validation(tmp_path, val_scores):
+    """With time-interval saving and validation only every 2 epochs, ensure no save uses stale/missing validation
+    metrics; the first save should happen at the first validation end (epoch 2)."""
+    seed_everything(123)
+
+    train_loader, val_loader = _make_loaders(n=4, batch_size=2)
+
+    model = ValOnceEveryTwoEpochsModule(val_scores=val_scores)
+
+    ckpt = ModelCheckpoint(
+        dirpath=tmp_path,
+        monitor="auroc",
+        mode="max",
+        save_top_k=1,
+        every_n_train_steps=0,  # disable step-based
+        train_time_interval=timedelta(seconds=0),  # trigger frequently
+        every_n_epochs=0,
+        save_on_train_epoch_end=False,
+        save_weights_only=True,
+    )
+
+    trainer = Trainer(
+        max_epochs=2,
+        accelerator="cpu",
+        devices=1,
+        callbacks=[ckpt],
+        num_sanity_val_steps=0,
+        log_every_n_steps=1,
+        limit_train_batches=2,
+        limit_val_batches=1,
+        enable_checkpointing=True,
+        enable_model_summary=False,
+        logger=False,
+        check_val_every_n_epoch=2,  # first validation only after 2nd epoch
+    )
+
+    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
+
+    assert ckpt.best_model_score is not None
+    expected = val_scores[1]  # validation runs only once at epoch 2, logging index 1
+    actual = float(ckpt.best_model_score)
+    assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6)
diff --git a/tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py b/tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py
new file mode 100644
index 0000000000000..c3fa0bfcd2e38
--- /dev/null
+++ b/tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py
@@ -0,0 +1,106 @@
+import math
+
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset
+
+from lightning.pytorch import LightningModule, Trainer, seed_everything
+from lightning.pytorch.callbacks import ModelCheckpoint
+
+
+class TinyDataset(Dataset):
+    def __init__(self, n: int = 4):
+        self.x = torch.arange(n, dtype=torch.float32).view(-1, 1)
+        self.y = self.x.clone()
+
+    def __len__(self):
+        return len(self.x)
+
+    def __getitem__(self, idx):
+        return self.x[idx], self.y[idx]
+
+
+class ValMetricModule(LightningModule):
+    def __init__(self, val_scores: list[float]):
+        super().__init__()
+        self.layer = nn.Linear(1, 1)
+        self._val_scores = [float(s) for s in val_scores]
+
+    # LightningModule API (minimal)
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.layer(x)
+        loss = F.mse_loss(y_hat, y)
+        return {"loss": loss}
+
+    def validation_step(self, batch, batch_idx):
+        # do nothing per-step; we log at epoch end
+        pass
+
+    def on_validation_epoch_end(self):
+        # Log a validation metric only at validation epoch end
+        # Values increase across epochs; best should be the last epoch
+        score = self._val_scores[self.current_epoch]
+        # use logger=True so it lands in trainer.callback_metrics
+        self.log("auroc", torch.tensor(score, dtype=torch.float32), prog_bar=False, logger=True)
+
+    def configure_optimizers(self):
+        return torch.optim.SGD(self.parameters(), lr=0.01)
+
+
+@pytest.mark.parametrize("val_scores", [[0.1, 0.5, 1.0]])
+def test_model_checkpoint_every_n_train_steps_with_val_metric_saves_after_val(tmp_path, val_scores):
+    """Reproduces #20919: Using every_n_train_steps with a validation-only metric should save the best checkpoint only
+    after the metric is computed at validation, not earlier at the train-step boundary.
+
+    Expectation: best_model_score equals the last (max) val score.
+
+    """
+    seed_everything(123)
+
+    # 2 train batches per epoch (so checkpoint triggers at the epoch boundary)
+    ds = TinyDataset(n=4)
+    train_loader = DataLoader(ds, batch_size=2, shuffle=False)
+    val_loader = DataLoader(ds, batch_size=2, shuffle=False)
+
+    model = ValMetricModule(val_scores=val_scores)
+
+    ckpt = ModelCheckpoint(
+        dirpath=tmp_path,
+        monitor="auroc",
+        mode="max",
+        save_top_k=1,
+        # critical: trigger on train steps, not on epoch end
+        every_n_train_steps=2,  # equal to number of train batches per epoch
+        train_time_interval=None,
+        every_n_epochs=0,
+        save_on_train_epoch_end=False,
+        save_weights_only=True,
+    )
+
+    trainer = Trainer(
+        max_epochs=len(val_scores),
+        accelerator="cpu",
+        devices=1,
+        callbacks=[ckpt],
+        num_sanity_val_steps=0,
+        log_every_n_steps=1,
+        limit_train_batches=2,
+        limit_val_batches=1,
+        enable_checkpointing=True,
+        enable_model_summary=False,
+        logger=False,
+    )
+
+    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
+
+    assert ckpt.best_model_score is not None
+    # Should equal the last (max) validation score
+    expected = max(val_scores)
+    actual = float(ckpt.best_model_score)
+    assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6), (
+        f"best_model_score should be {expected} (last/maximum val score), got {actual}.\n"
+        f"This indicates the checkpoint was saved before the validation metric was computed."
+    )
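Reviewer note: the user-facing configuration this patch targets can be sketched outside the test suite roughly as below. This is only an illustration of the affected setup, not part of the patch; the metric name `val_auroc`, the interval values, and `MyLightningModule` are placeholders, and only public `ModelCheckpoint`/`Trainer` arguments already exercised by the tests above are used. With such a configuration, step-triggered saves previously scored checkpoints against a missing or stale validation metric; after this change they are deferred to `on_validation_end`.

    # Minimal repro sketch (hypothetical names; mirrors the scenario from the tests above).
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint

    ckpt = ModelCheckpoint(
        monitor="val_auroc",            # assumed metric, logged only in on_validation_epoch_end
        mode="max",
        save_top_k=1,
        every_n_train_steps=200,        # step-triggered saves ...
        save_on_train_epoch_end=False,  # ... which must wait for fresh validation metrics
    )
    trainer = Trainer(max_epochs=3, val_check_interval=0.5, callbacks=[ckpt])
    # trainer.fit(MyLightningModule(), train_dataloaders=..., val_dataloaders=...)  # MyLightningModule is hypothetical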