
Commit aea989c

baskrahmer authored and Borda committed
Fix save_last behavior in absence of validation (#20960)
* Implement test to reproduce the issue
* Implement fix
* chlog

--------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka B <[email protected]>

(cherry picked from commit 72bb751)
1 parent 38c4059 commit aea989c
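For context, the scenario this commit targets can be sketched with only the public Lightning API (the directory name below is illustrative, not taken from the patch): a model with no validation loop combined with `ModelCheckpoint(save_last=True, save_on_train_epoch_end=False)`. Because `on_validation_end` never fires in that configuration, nothing used to trigger the `last.ckpt` save.

# Illustrative sketch only (not part of this commit): reproduces the setup the
# fix addresses, using the public API. BoringModel stands in for any
# LightningModule without validation_step/val_dataloader.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
model.validation_step = None  # disable validation entirely
model.val_dataloader = None

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",           # hypothetical output directory
    save_last=True,                  # expect checkpoints/last.ckpt after fit()
    save_on_train_epoch_end=False,   # defer saving to the validation hook
)

trainer = Trainer(max_epochs=2, callbacks=[checkpoint_callback], logger=False)
trainer.fit(model)
# Before this fix, last.ckpt could be missing here because on_validation_end
# never ran; the new on_train_end hook now writes it when training ends.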

File tree

3 files changed (+37, -0 lines)


src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed metrics in `RichProgressBar` being updated according to user provided `refresh_rate` ([#21032](https://github.com/Lightning-AI/pytorch-lightning/pull/21032))
 
 
+- Fix `save_last` behavior in the absence of validation ([#20960](https://github.com/Lightning-AI/pytorch-lightning/pull/20960))
+
+
 ---
 
 ## [2.5.2] - 2025-06-20

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 7 additions & 0 deletions
@@ -344,6 +344,13 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._save_topk_checkpoint(trainer, monitor_candidates)
         self._save_last_checkpoint(trainer, monitor_candidates)
 
+    @override
+    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Ensure save_last=True is applied when training ends."""
+        if self.save_last and not self._last_checkpoint_saved:
+            monitor_candidates = self._monitor_candidates(trainer)
+            self._save_last_checkpoint(trainer, monitor_candidates)
+
     @override
     def state_dict(self) -> dict[str, Any]:
         return {
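The hook placement is the key design point: `on_train_end` runs at the end of every fit, with or without a validation loop, so checking `_last_checkpoint_saved` there catches runs where neither `on_validation_end` nor the train-epoch-end path wrote `last.ckpt`. As a rough illustration of the same idea using only the public `Callback` API (the class name and path below are made up for this sketch, not part of the patch):

# Illustrative sketch only: a standalone callback expressing the same
# guarantee. The real fix lives inside ModelCheckpoint and reuses its
# _save_last_checkpoint logic and deduplication flag.
import os

from lightning.pytorch import Callback, LightningModule, Trainer


class SaveLastOnTrainEnd(Callback):
    """Write a final checkpoint when training ends, independent of validation."""

    def __init__(self, dirpath: str) -> None:
        self.dirpath = dirpath

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # on_train_end fires even when no validation loop exists, so the
        # "last" checkpoint is never skipped.
        trainer.save_checkpoint(os.path.join(self.dirpath, "last.ckpt"))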

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 27 additions & 0 deletions
@@ -1666,3 +1666,30 @@ def val_dataloader(self) -> DataLoader:
     trainer_kwargs["max_epochs"] = 4
     trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
     trainer.fit(model, ckpt_path=checkpoint_path)
+
+
+def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path):
+    """Test that save_last=True works correctly when save_on_train_epoch_end=False in a model without validation."""
+
+    # Remove validation methods to test the edge case
+    model = BoringModel()
+    model.validation_step = None
+    model.val_dataloader = None
+
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmp_path,
+        save_last=True,
+        save_on_train_epoch_end=False,
+    )
+
+    trainer = Trainer(
+        max_epochs=2,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_progress_bar=False,
+    )
+
+    trainer.fit(model)
+
+    # save_last=True should always save last.ckpt
+    assert (tmp_path / "last.ckpt").exists()
