Skip to content

Commit b2b9efe

Browse files
committed
fix implementation
1 parent e088694 commit b2b9efe

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,10 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
380380
monitor_candidates = self._monitor_candidates(trainer)
381381
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
382382
self._save_topk_checkpoint(trainer, monitor_candidates)
383-
self._save_last_checkpoint(trainer, monitor_candidates)
383+
# Only save the last checkpoint if a checkpoint was actually saved at this global step, or if save_last="link" and a previously saved checkpoint is recorded
384+
if (self._last_global_step_saved == trainer.global_step or
385+
(self.save_last == "link" and self._last_checkpoint_saved)):
386+
self._save_last_checkpoint(trainer, monitor_candidates)
384387

385388
@override
386389
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -397,7 +400,10 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
397400

398401
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
399402
self._save_topk_checkpoint(trainer, monitor_candidates)
400-
self._save_last_checkpoint(trainer, monitor_candidates)
403+
# Only save the last checkpoint if a checkpoint was actually saved at this global step, or if save_last="link" and a previously saved checkpoint is recorded
404+
if (self._last_global_step_saved == trainer.global_step or
405+
(self.save_last == "link" and self._last_checkpoint_saved)):
406+
self._save_last_checkpoint(trainer, monitor_candidates)
401407

402408
@override
403409
def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
@@ -902,3 +908,5 @@ def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, curren
902908
def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
903909
"""Calls the strategy to remove the checkpoint file."""
904910
trainer.strategy.remove_checkpoint(filepath)
911+
912+

0 commit comments

Comments
 (0)