Skip to content

Commit 3bc6e9f

Browse files
committed
updated model_checkpoint.py to add the facility of retaining periodic checkpoints
1 parent a944e77 commit 3bc6e9f

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@ class ModelCheckpoint(Checkpoint):
136136
If this is ``False``, then the check runs at the end of the validation.
137137
enable_version_counter: Whether to append a version to the existing file name.
138138
If this is ``False``, then the checkpoint files will be overwritten.
139+
retain_periodic_ckpt: Whether to retain the periodic checkpoints when multiple checkpoints are
140+
saved. If this is ``False``, then only the latest checkpoint will be saved. If this is ``True``,
141+
don't change the default value of ``save_top_k``.
142+
Default: ``False``.
139143
140144
Note:
141145
For extra customization, ModelCheckpoint includes the following attributes:
@@ -228,6 +232,7 @@ def __init__(
228232
every_n_epochs: Optional[int] = None,
229233
save_on_train_epoch_end: Optional[bool] = None,
230234
enable_version_counter: bool = True,
235+
retain_periodic_ckpt: bool = False,
231236
):
232237
super().__init__()
233238
self.monitor = monitor
@@ -247,6 +252,7 @@ def __init__(
247252
self.best_model_path = ""
248253
self.last_model_path = ""
249254
self._last_checkpoint_saved = ""
255+
self.retain_periodic_ckpt = retain_periodic_ckpt
250256

251257
self.kth_value: Tensor
252258
self.dirpath: Optional[_PATH]
@@ -714,7 +720,12 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
714720
previous, self.best_model_path = self.best_model_path, filepath
715721
self._save_checkpoint(trainer, filepath)
716722

717-
if self.save_top_k == 1 and previous and self._should_remove_checkpoint(trainer, previous, filepath):
723+
if (
724+
self.save_top_k == 1
725+
and not self.retain_periodic_ckpt
726+
and previous
727+
and self._should_remove_checkpoint(trainer, previous, filepath)
728+
):
718729
self._remove_checkpoint(trainer, previous)
719730

720731
def _update_best_and_save(

0 commit comments

Comments
 (0)