@@ -26,7 +26,7 @@
 from copy import deepcopy
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Dict, Optional, Set
+from typing import Any, Dict, Literal, Optional, Set
 from weakref import proxy
 
 import torch
@@ -83,9 +83,9 @@ class ModelCheckpoint(Checkpoint):
             the number of finished epoch and optimizer steps respectively.
         monitor: quantity to monitor. By default it is ``None`` which saves a checkpoint only for the last epoch.
         verbose: verbosity mode. Default: ``False``.
-        save_last: When ``True``, saves a `last.ckpt` whenever a checkpoint file gets saved. On a local filesystem,
-            this will be a symbolic link, and otherwise a copy of the checkpoint file. This allows accessing the latest
-            checkpoint in a deterministic manner. Default: ``None``.
+        save_last: When ``True``, saves a `last.ckpt` copy whenever a checkpoint file gets saved. Can be set to
+            ``'link'`` on a local filesystem to create a symbolic link. This allows accessing the latest checkpoint
+            in a deterministic manner. Default: ``None``.
         save_top_k: if ``save_top_k == k``,
             the best k models according to the quantity monitored will be saved.
             if ``save_top_k == 0``, no models are saved.
@@ -216,7 +216,7 @@ def __init__(
         filename: Optional[str] = None,
         monitor: Optional[str] = None,
         verbose: bool = False,
-        save_last: Optional[bool] = None,
+        save_last: Optional[Literal[True, False, "link"]] = None,
         save_top_k: int = 1,
         save_weights_only: bool = False,
         mode: str = "min",
@@ -272,6 +272,10 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
         self._fs = get_filesystem(self.dirpath or "")
         if trainer.is_global_zero and stage == "fit":
             self.__warn_if_dir_not_empty(self.dirpath)
+        if self.save_last == "link" and not _is_local_file_protocol(self.dirpath):
+            raise ValueError(
+                f"`ModelCheckpoint(save_last='link')` is only supported for local file paths, got `dirpath={dirpath}`."
+            )
 
     @override
     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -684,7 +688,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
 
         # set the last model path before saving because it will be part of the state.
         previous, self.last_model_path = self.last_model_path, filepath
-        if _is_local_file_protocol(filepath) and self._last_checkpoint_saved and self.save_top_k != 0:
+        if self.save_last == "link" and self._last_checkpoint_saved and self.save_top_k != 0:
            self._link_checkpoint(trainer, self._last_checkpoint_saved, filepath)
         else:
             self._save_checkpoint(trainer, filepath)