Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,13 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)

@override
def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Ensure save_last=True is applied when training ends."""
if self.save_last and not self._last_checkpoint_saved:
monitor_candidates = self._monitor_candidates(trainer)
self._save_last_checkpoint(trainer, monitor_candidates)

@override
def state_dict(self) -> dict[str, Any]:
return {
Expand Down
27 changes: 27 additions & 0 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1666,3 +1666,30 @@ def val_dataloader(self) -> DataLoader:
trainer_kwargs["max_epochs"] = 4
trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
trainer.fit(model, ckpt_path=checkpoint_path)


def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path):
"""Test that save_last=True works correctly when save_on_train_epoch_end=False in a model without validation."""

# Remove validation methods to test the edge case
model = BoringModel()
model.validation_step = None
model.val_dataloader = None

checkpoint_callback = ModelCheckpoint(
dirpath=tmp_path,
save_last=True,
save_on_train_epoch_end=False,
)

trainer = Trainer(
max_epochs=2,
callbacks=[checkpoint_callback],
logger=False,
enable_progress_bar=False,
)

trainer.fit(model)

# save_last=True should always save last.ckpt
assert (tmp_path / "last.ckpt").exists()
Loading