Commit c5e3c45

awaelchli and carmocca authored
Save ModelCheckpoint's last.ckpt as symlink if possible (#18748)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 7434c47 commit c5e3c45

6 files changed: 44 additions, 30 deletions

src/lightning/pytorch/CHANGELOG.md (1 addition, 0 deletions)

@@ -90,6 +90,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The `ModelCheckpoint` no longer deletes files under the save-top-k mechanism when resuming from a folder that is not the same as the current checkpoint folder ([#18750](https://github.com/Lightning-AI/lightning/pull/18750))
 - The `ModelCheckpoint` no longer deletes the file that was passed to `Trainer.fit(ckpt_path=...)` ([#18750](https://github.com/Lightning-AI/lightning/pull/18750))
 - Calling `trainer.fit()` twice now raises an error with strategies that spawn subprocesses through `multiprocessing` (ddp_spawn, xla) ([#18776](https://github.com/Lightning-AI/lightning/pull/18776))
+- The `ModelCheckpoint` now saves a symbolic link if `save_last=True` and `save_top_k != 0` ([#18748](https://github.com/Lightning-AI/lightning/pull/18748))
 
 ### Deprecated
 
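
A minimal usage sketch of how this surfaces to users (the directory path and monitored metric below are illustrative, not from the commit); on a local filesystem, `last.ckpt` ends up as a symlink to the most recently saved checkpoint:

```python
import os
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# Keep the best 2 checkpoints and also maintain a `last.ckpt` pointer.
checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/",  # illustrative local directory
    monitor="val_loss",      # illustrative metric name
    save_top_k=2,
    save_last=True,          # with save_top_k != 0, last.ckpt becomes a symlink locally
)
trainer = Trainer(callbacks=[checkpoint_cb], max_epochs=3)
# trainer.fit(model)  # after fitting, on a local filesystem:
# assert os.path.islink("checkpoints/last.ckpt")
```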

src/lightning/pytorch/callbacks/model_checkpoint.py (22 additions, 15 deletions)
@@ -81,8 +81,9 @@ class ModelCheckpoint(Checkpoint):
             the number of finished epoch and optimizer steps respectively.
         monitor: quantity to monitor. By default it is ``None`` which saves a checkpoint only for the last epoch.
         verbose: verbosity mode. Default: ``False``.
-        save_last: When ``True``, saves an exact copy of the checkpoint to a file `last.ckpt` whenever a checkpoint
-            file gets saved. This allows accessing the latest checkpoint in a deterministic manner. Default: ``None``.
+        save_last: When ``True``, saves a `last.ckpt` whenever a checkpoint file gets saved. On a local filesystem,
+            this will be a symbolic link, and otherwise a copy of the checkpoint file. This allows accessing the latest
+            checkpoint in a deterministic manner. Default: ``None``.
         save_top_k: if ``save_top_k == k``,
             the best k models according to the quantity monitored will be saved.
             if ``save_top_k == 0``, no models are saved.
@@ -241,6 +242,7 @@ def __init__(
         self.best_model_score: Optional[Tensor] = None
         self.best_model_path = ""
         self.last_model_path = ""
+        self._last_checkpoint_saved = ""
 
         self.kth_value: Tensor
         self.dirpath: Optional[_PATH]
@@ -371,12 +373,21 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
         trainer.save_checkpoint(filepath, self.save_weights_only)
 
         self._last_global_step_saved = trainer.global_step
+        self._last_checkpoint_saved = filepath
 
         # notify loggers
         if trainer.is_global_zero:
             for logger in trainer.loggers:
                 logger.after_save_checkpoint(proxy(self))
 
+    @staticmethod
+    def _link_checkpoint(trainer: "pl.Trainer", filepath: str, linkpath: str) -> None:
+        if trainer.is_global_zero:
+            if os.path.lexists(linkpath):
+                os.remove(linkpath)
+            os.symlink(filepath, linkpath)
+        trainer.strategy.barrier()
+
     def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
         from lightning.pytorch.trainer.states import TrainerFn
 
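
Why `lexists` rather than `exists` in `_link_checkpoint` above: a previous `last.ckpt` symlink may already occupy `linkpath`, and it may even be dangling if its target was deleted; `os.path.lexists` still reports it so it can be removed before re-linking. A self-contained stdlib sketch of that pattern (POSIX symlink semantics assumed):

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "epoch=0-step=10.ckpt")  # stand-in checkpoint name
    link = os.path.join(tmp, "last.ckpt")

    open(target, "w").close()        # stand-in for a saved checkpoint file
    os.symlink(target, link)
    os.remove(target)                # the link is now dangling

    assert not os.path.exists(link)  # exists() follows the link: target is gone
    assert os.path.lexists(link)     # lexists() still sees the symlink itself

    # Replace the stale link, as _link_checkpoint does:
    if os.path.lexists(link):
        os.remove(link)
    open(target, "w").close()
    os.symlink(target, link)
    assert os.path.realpath(link) == os.path.realpath(target)
```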

@@ -427,19 +438,12 @@ def __validate_init_configuration(self) -> None:
                 "should be mutually exclusive."
             )
 
-        if self.monitor is None:
+        if self.monitor is None and self.save_top_k not in (-1, 0, 1):
             # -1: save all epochs, 0: nothing is saved, 1: save last epoch
-            if self.save_top_k not in (-1, 0, 1):
-                raise MisconfigurationException(
-                    f"ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid"
-                    " configuration. No quantity for top_k to track."
-                )
-
-            if self.save_top_k == -1 and self.save_last:
-                rank_zero_info(
-                    "ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)"
-                    " will duplicate the last checkpoint saved."
-                )
+            raise MisconfigurationException(
+                f"ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid"
+                " configuration. No quantity for top_k to track."
+            )
 
     def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None:
         self._fs = get_filesystem(dirpath if dirpath else "")

Note: the `rank_zero_info` notice about duplicating the last checkpoint is dropped because, with `last.ckpt` now saved as a symlink where possible, `save_last=True` with `save_top_k=-1` no longer writes a duplicate copy.
@@ -662,7 +666,10 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
 
         # set the last model path before saving because it will be part of the state.
         previous, self.last_model_path = self.last_model_path, filepath
-        self._save_checkpoint(trainer, filepath)
+        if self._fs.protocol == "file" and self._last_checkpoint_saved and self.save_top_k != 0:
+            self._link_checkpoint(trainer, self._last_checkpoint_saved, filepath)
+        else:
+            self._save_checkpoint(trainer, filepath)
         if previous and self._should_remove_checkpoint(trainer, previous, filepath):
             self._remove_checkpoint(trainer, previous)
 
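
The `self._fs.protocol == "file"` guard decides whether linking is possible at all: `self._fs` is the fsspec filesystem resolved from `dirpath` via `get_filesystem` (see `__init_ckpt_dir` above), and only local paths support `os.symlink`. A rough illustration with fsspec directly (paths are illustrative, `memory://` is used so the snippet needs no cloud packages, and exact `protocol` values can vary across fsspec versions):

```python
import fsspec

# A plain local path resolves to LocalFileSystem:
local_fs, _ = fsspec.core.url_to_fs("/tmp/checkpoints")
print(local_fs.protocol)  # typically "file" -> last.ckpt can be a symlink

# Any non-local URL falls back to saving a full copy of last.ckpt:
mem_fs, _ = fsspec.core.url_to_fs("memory://checkpoints")
print(mem_fs.protocol)    # "memory" -> _save_checkpoint is used instead
```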

tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py (2 additions, 3 deletions)
@@ -62,7 +62,7 @@ def __init__(self):
         self.last_coeff = 10.0
 
     def training_step(self, batch, batch_idx):
-        loss = self.step(torch.ones(32))
+        loss = self.step(torch.ones(32, device=self.device))
         loss = loss / (loss + 0.0000001)
         loss += self.last_coeff
         self.log("my_loss", loss)
@@ -80,8 +80,7 @@ def training_step(self, batch, batch_idx):
     trainer.fit(model)
 
     if save_last:
-        # last epochs are saved every step (so double the save calls)
-        expected = expected * 2
+        expected = expected
     assert save_mock.call_count == expected
 

Note: enabling `save_last` no longer doubles the expected save count, since `last.ckpt` is now linked instead of written through a second save call; the dummy loss input is also created on the model's device.

tests/tests_pytorch/checkpointing/test_model_checkpoint.py (11 additions, 8 deletions)
@@ -18,7 +18,6 @@
 import time
 from argparse import Namespace
 from datetime import timedelta
-from logging import INFO
 from pathlib import Path
 from typing import Union
 from unittest import mock
@@ -510,7 +509,8 @@ def test_model_checkpoint_save_last(tmpdir):
     assert set(os.listdir(tmpdir)) == set(
         [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [10, 20, 30])] + [last_filename]
     )
-
+    assert os.path.islink(tmpdir / last_filename)
+    assert os.path.realpath(tmpdir / last_filename) == model_checkpoint._last_checkpoint_saved
     ModelCheckpoint.CHECKPOINT_NAME_LAST = "last"
 
 
@@ -589,10 +589,7 @@ def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
         max_epochs=epochs,
         logger=False,
     )
-
-    with caplog.at_level(INFO):
-        trainer.fit(model)
-    assert "will duplicate the last checkpoint saved" in caplog.text
+    trainer.fit(model)
 
     # these should not be set if monitor is None
     assert checkpoint_callback.monitor is None
@@ -606,6 +603,7 @@
     expected = [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [10, 20])]
     expected.append("last.ckpt")
     assert set(os.listdir(tmpdir)) == set(expected)
+    assert os.path.islink(tmpdir / "last.ckpt")
 
 
 @pytest.mark.parametrize("every_n_epochs", list(range(4)))
@@ -709,6 +707,8 @@ def test_model_checkpoint_topk_zero(tmpdir):
     # check that only the last ckpt was created
     assert os.listdir(tmpdir) == ["last.ckpt"]
     assert checkpoint_callback.last_model_path == tmpdir / "last.ckpt"
+    # 'last.ckpt' is not a symlink because there are no top-k checkpoints to link
+    assert not os.path.islink(checkpoint_callback.last_model_path)
 
 
 def test_model_checkpoint_topk_all(tmpdir):
@@ -814,6 +814,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     path_last = str(tmpdir / "last.ckpt")
     assert path_last == model_checkpoint.last_model_path
     assert os.path.isfile(path_last_epoch)
+    assert os.path.islink(path_last)
 
     ckpt_last_epoch = torch.load(path_last_epoch)
     ckpt_last = torch.load(path_last)
@@ -1343,7 +1344,7 @@ def test_save_last_saves_correct_last_model_path(tmpdir):
     trainer = Trainer(callbacks=mc)
     trainer.strategy.connect(BoringModel())
 
-    mc._save_last_checkpoint(trainer, {"foo": 1})
+    mc._save_last_checkpoint(trainer, {"foo": torch.tensor(1)})
     expected = "foo=1-last.ckpt"
     assert os.listdir(tmpdir) == [expected]
     full_path = str(tmpdir / expected)
@@ -1366,6 +1367,8 @@ def test_save_last_versioning(tmpdir):
     )
     trainer.fit(model)
     assert {"last.ckpt", "last-v1.ckpt"} == set(os.listdir(tmpdir))
+    # 'last.ckpt' is not a symlink since `save_top_k=0` didn't save any other checkpoints to link to
+    assert all(not os.path.islink(tmpdir / path) for path in set(os.listdir(tmpdir)))
 
 
 def test_none_monitor_saves_correct_best_model_path(tmpdir):
@@ -1385,7 +1388,7 @@ def test_last_global_step_saved():
     # this should not save anything
     model_checkpoint = ModelCheckpoint(save_top_k=0, save_last=False, monitor="foo")
     trainer = Mock()
-    monitor_candidates = {"foo": 123}
+    monitor_candidates = {"foo": torch.tensor(123)}
     model_checkpoint._save_topk_checkpoint(trainer, monitor_candidates)
     model_checkpoint._save_last_checkpoint(trainer, monitor_candidates)
     assert model_checkpoint._last_global_step_saved == 0
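
The `islink`/`realpath` assertions added above boil down to one property: a symlinked `last.ckpt` resolves to, and loads exactly like, the checkpoint it points to. A small sketch with hypothetical file names:

```python
import os
import tempfile

import torch

with tempfile.TemporaryDirectory() as tmp:
    real = os.path.join(tmp, "epoch=1-step=20.ckpt")  # hypothetical checkpoint name
    link = os.path.join(tmp, "last.ckpt")

    torch.save({"epoch": 1}, real)  # minimal stand-in payload
    os.symlink(real, link)

    assert os.path.islink(link)
    assert os.path.realpath(link) == os.path.realpath(real)
    assert torch.load(link)["epoch"] == 1  # torch.load follows the symlink
```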

tests/tests_pytorch/models/test_restore.py (4 additions, 2 deletions)
@@ -311,9 +311,11 @@ def get_trainer_args():
         "best_k_models",
         "kth_best_model_path",
         "kth_value",
-        "last_model_path",
     ):
-        assert getattr(before, attribute) == getattr(after, attribute)
+        assert getattr(before, attribute) == getattr(after, attribute), f"{attribute}"
+    # `before.last_model_path` is a symlink pointing to a checkpoint saved before that symlink was created,
+    # hence reloading that checkpoint will restore `after.last_model_path = ""`
+    assert after.last_model_path == ""
 
 
 @RunIf(sklearn=True)

tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py (4 additions, 2 deletions)
@@ -47,6 +47,7 @@ def test_checkpoint_plugin_called(tmpdir):
     model = BoringModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
+        accelerator="cpu",
         strategy=SingleDeviceStrategy("cpu", checkpoint_io=checkpoint_plugin),
         callbacks=ck,
         max_epochs=2,
@@ -60,7 +61,7 @@
     assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt"}
     assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2.ckpt"
     assert trainer.checkpoint_callback.last_model_path == tmpdir / "last.ckpt"
-    assert checkpoint_plugin.save_checkpoint.call_count == 4
+    assert checkpoint_plugin.save_checkpoint.call_count == 2
     assert checkpoint_plugin.remove_checkpoint.call_count == 1
 
     trainer.test(model, ckpt_path=ck.last_model_path)
@@ -72,6 +73,7 @@
     model = BoringModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
+        accelerator="cpu",
         strategy=SingleDeviceStrategy("cpu"),
         plugins=[checkpoint_plugin],
         callbacks=ck,
@@ -86,7 +88,7 @@
     assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt", "epoch=1-step=2-v1.ckpt", "last-v1.ckpt"}
     assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2-v1.ckpt"
     assert trainer.checkpoint_callback.last_model_path == tmpdir / "last-v1.ckpt"
-    assert checkpoint_plugin.save_checkpoint.call_count == 4
+    assert checkpoint_plugin.save_checkpoint.call_count == 2
     assert checkpoint_plugin.remove_checkpoint.call_count == 1
 
     trainer.test(model, ckpt_path=ck.last_model_path)

Note: `save_checkpoint.call_count` drops from 4 to 2 because `last.ckpt` is now created with `os.symlink` rather than routed through the `CheckpointIO` plugin, so only the per-epoch top-k saves reach the plugin.
