
Commit 59d2600

Make saving 'last' checkpoint as symbolic link opt-in (#19191)
1 parent: c3e2ba5

File tree: 3 files changed, +31 −11 lines

src/lightning/pytorch/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
```diff
@@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `TransformerEnginePrecision(fallback_compute_dtype=)` to control the dtype of operations that don't support fp8 ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
 
 
+- Added the option `ModelCheckpoint(save_last='link')` to create a symbolic link for the 'last.ckpt' file ([#19191](https://github.com/Lightning-AI/lightning/pull/19191))
+
+
 ### Changed
 
 - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
@@ -47,6 +50,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The columns in the `metrics.csv` file produced by `CSVLogger` are now sorted alphabetically ([#19159](https://github.com/Lightning-AI/lightning/pull/19159))
 
 
+- Reverted back to creating a checkpoint copy when `ModelCheckpoint(save_last=True)` instead of creating a symbolic link ([#19191](https://github.com/Lightning-AI/lightning/pull/19191))
+
+
 ### Deprecated
 
 - Deprecated all precision plugin classes under `lightning.pytorch.plugins` with the suffix `Plugin` in the name ([#18840](https://github.com/Lightning-AI/lightning/pull/18840))
```
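Taken together, these two entries mean `save_last=True` once again writes `last.ckpt` as a full copy, and the symlink behavior becomes opt-in via `save_last='link'`. A minimal usage sketch (the `checkpoints/` path is an illustrative placeholder, not part of this commit):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# Opt in to the symlink behavior: `last.ckpt` becomes a symbolic link to
# the most recently saved checkpoint. Requires a local filesystem `dirpath`.
link_cb = ModelCheckpoint(dirpath="checkpoints/", save_last="link")

# `save_last=True` reverts to writing `last.ckpt` as a standalone copy of
# the latest checkpoint, as it did before the symlink change.
copy_cb = ModelCheckpoint(dirpath="checkpoints/", save_last=True)

trainer = Trainer(callbacks=[link_cb])
```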

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 10 additions & 6 deletions
```diff
@@ -26,7 +26,7 @@
 from copy import deepcopy
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Dict, Optional, Set
+from typing import Any, Dict, Literal, Optional, Set
 from weakref import proxy
 
 import torch
@@ -83,9 +83,9 @@ class ModelCheckpoint(Checkpoint):
             the number of finished epoch and optimizer steps respectively.
         monitor: quantity to monitor. By default it is ``None`` which saves a checkpoint only for the last epoch.
         verbose: verbosity mode. Default: ``False``.
-        save_last: When ``True``, saves a `last.ckpt` whenever a checkpoint file gets saved. On a local filesystem,
-            this will be a symbolic link, and otherwise a copy of the checkpoint file. This allows accessing the latest
-            checkpoint in a deterministic manner. Default: ``None``.
+        save_last: When ``True``, saves a `last.ckpt` copy whenever a checkpoint file gets saved. Can be set to
+            ``'link'`` on a local filesystem to create a symbolic link. This allows accessing the latest checkpoint
+            in a deterministic manner. Default: ``None``.
         save_top_k: if ``save_top_k == k``,
             the best k models according to the quantity monitored will be saved.
             if ``save_top_k == 0``, no models are saved.
@@ -216,7 +216,7 @@ def __init__(
         filename: Optional[str] = None,
         monitor: Optional[str] = None,
         verbose: bool = False,
-        save_last: Optional[bool] = None,
+        save_last: Optional[Literal[True, False, "link"]] = None,
         save_top_k: int = 1,
         save_weights_only: bool = False,
         mode: str = "min",
```
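The widened annotation documents the accepted values directly in the signature. A quick sketch of what a static type checker such as mypy would accept or flag under this `Literal` (illustrative only; the diff adds no runtime validation of arbitrary strings):

```python
from lightning.pytorch.callbacks import ModelCheckpoint

ModelCheckpoint(save_last=None)    # OK: default, no `last.ckpt` is written
ModelCheckpoint(save_last=True)    # OK: `last.ckpt` saved as a copy
ModelCheckpoint(save_last="link")  # OK: `last.ckpt` saved as a symlink
ModelCheckpoint(save_last="copy")  # flagged by the type checker: not in the Literal
```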
```diff
@@ -272,6 +272,10 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
         self._fs = get_filesystem(self.dirpath or "")
         if trainer.is_global_zero and stage == "fit":
             self.__warn_if_dir_not_empty(self.dirpath)
+            if self.save_last == "link" and not _is_local_file_protocol(self.dirpath):
+                raise ValueError(
+                    f"`ModelCheckpoint(save_last='link')` is only supported for local file paths, got `dirpath={dirpath}`."
+                )
 
     @override
     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
```
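The new guard fails fast during `setup()` when a symlink is requested on a non-local filesystem. A hedged sketch of how this surfaces in user code (the `s3://` URL is an illustrative remote path, analogous to the `memory://` path in the new test below):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

callback = ModelCheckpoint(dirpath="s3://my-bucket/checkpoints", save_last="link")
trainer = Trainer(callbacks=[callback])

# When the `setup()` hook runs at the start of `trainer.fit(model)`, it raises:
#   ValueError: `ModelCheckpoint(save_last='link')` is only supported for
#   local file paths, got `dirpath=s3://my-bucket/checkpoints`.
```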
```diff
@@ -684,7 +688,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
 
         # set the last model path before saving because it will be part of the state.
         previous, self.last_model_path = self.last_model_path, filepath
-        if _is_local_file_protocol(filepath) and self._last_checkpoint_saved and self.save_top_k != 0:
+        if self.save_last == "link" and self._last_checkpoint_saved and self.save_top_k != 0:
             self._link_checkpoint(trainer, self._last_checkpoint_saved, filepath)
         else:
             self._save_checkpoint(trainer, filepath)
```
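This change switches the link-vs-copy decision from "is the path local?" to "did the user opt in?". A simplified, standalone sketch of the resulting control flow (not the actual implementation; `save_fn` stands in for Lightning's internal checkpoint writer):

```python
import os

def write_last_checkpoint(save_last, last_path, latest_ckpt, save_top_k, save_fn):
    """Illustrative sketch of the branch taken by `_save_last_checkpoint`."""
    if save_last == "link" and latest_ckpt and save_top_k != 0:
        # Point `last.ckpt` at the most recently saved checkpoint, removing
        # any stale link or file left over from a previous save.
        if os.path.lexists(last_path):
            os.remove(last_path)
        os.symlink(latest_ckpt, last_path)
    else:
        # `save_last=True`, nothing saved yet, or `save_top_k == 0`:
        # write `last.ckpt` as a standalone file.
        save_fn(last_path)
```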

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 15 additions & 5 deletions
```diff
@@ -485,13 +485,14 @@ def test_model_checkpoint_file_extension(tmpdir):
     assert set(expected) == set(os.listdir(tmpdir))
 
 
-def test_model_checkpoint_save_last(tmpdir, monkeypatch):
+@pytest.mark.parametrize("save_last", [True, "link"])
+def test_model_checkpoint_save_last(save_last, tmpdir, monkeypatch):
     """Tests that save_last produces only one last checkpoint."""
     seed_everything()
     model = LogInTwoMethods()
     epochs = 3
     monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_NAME_LAST", "last-{epoch}")
-    model_checkpoint = ModelCheckpoint(monitor="early_stop_on", dirpath=tmpdir, save_top_k=-1, save_last=True)
+    model_checkpoint = ModelCheckpoint(monitor="early_stop_on", dirpath=tmpdir, save_top_k=-1, save_last=save_last)
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[model_checkpoint],
@@ -509,10 +510,19 @@ def test_model_checkpoint_save_last(tmpdir, monkeypatch):
     assert set(os.listdir(tmpdir)) == set(
         [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [10, 20, 30])] + [last_filename]
     )
-    assert os.path.islink(tmpdir / last_filename)
+    if save_last == "link":
+        assert os.path.islink(tmpdir / last_filename)
+    else:
+        assert os.path.isfile(tmpdir / last_filename)
     assert os.path.realpath(tmpdir / last_filename) == model_checkpoint._last_checkpoint_saved
 
 
+def test_model_checkpoint_save_last_as_link_not_local(tmp_path):
+    callback = ModelCheckpoint(dirpath="memory://not-a-filesystem-path", save_last="link")
+    with pytest.raises(ValueError, match="save_last='link'.* is only supported for local file paths"):
+        callback.setup(trainer=Trainer(), pl_module=BoringModel(), stage="fit")
+
+
 def test_model_checkpoint_link_checkpoint(tmp_path):
     """Test that linking a checkpoint works and overwrites an existing link if present."""
     trainer = Mock()
```
```diff
@@ -676,7 +686,7 @@ def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
     expected = [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [10, 20])]
     expected.append("last.ckpt")
     assert set(os.listdir(tmpdir)) == set(expected)
-    assert os.path.islink(tmpdir / "last.ckpt")
+    assert os.path.isfile(tmpdir / "last.ckpt")
 
 
 @pytest.mark.parametrize("every_n_epochs", list(range(4)))
```
```diff
@@ -887,7 +897,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     path_last = str(tmpdir / "last.ckpt")
     assert path_last == model_checkpoint.last_model_path
     assert os.path.isfile(path_last_epoch)
-    assert os.path.islink(path_last)
+    assert os.path.isfile(path_last)
 
     ckpt_last_epoch = torch.load(path_last_epoch)
     ckpt_last = torch.load(path_last)
```
