Skip to content

Commit ff9c3f9

Browse files
awaelchli, justusschock, pre-commit-ci[bot], and carmocca
authored and committed
Add required states for resumed ModelCheckpoint GC (#10995)
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
1 parent 0a635d9 commit ff9c3f9

File tree

4 files changed

+53
-2
lines changed

4 files changed

+53
-2
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1818
### Changed
1919

2020
- DeepSpeed does not require lightning module zero 3 partitioning ([#10655](https://github.com/PyTorchLightning/pytorch-lightning/pull/10655))
21+
- The `ModelCheckpoint` callback now saves and restores attributes `best_k_models`, `kth_best_model_path`, `kth_value`, and `last_model_path` ([#10995](https://github.com/PyTorchLightning/pytorch-lightning/pull/10995))
22+
2123

2224
## [1.5.6] - 2021-12-15
2325

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,13 +357,21 @@ def on_save_checkpoint(
357357
"best_model_path": self.best_model_path,
358358
"current_score": self.current_score,
359359
"dirpath": self.dirpath,
360+
"best_k_models": self.best_k_models,
361+
"kth_best_model_path": self.kth_best_model_path,
362+
"kth_value": self.kth_value,
363+
"last_model_path": self.last_model_path,
360364
}
361365

362366
def on_load_checkpoint(
363367
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
364368
) -> None:
365369
self.best_model_score = callback_state["best_model_score"]
366370
self.best_model_path = callback_state["best_model_path"]
371+
self.best_k_models = callback_state.get("best_k_models", self.best_k_models)
372+
self.kth_best_model_path = callback_state.get("kth_best_model_path", self.kth_best_model_path)
373+
self.kth_value = callback_state.get("kth_value", self.kth_value)
374+
self.last_model_path = callback_state.get("last_model_path", self.last_model_path)
367375

368376
def save_checkpoint(self, trainer: "pl.Trainer") -> None:
369377
"""Performs the main logic around saving a checkpoint.

tests/checkpointing/test_model_checkpoint.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,3 +1206,37 @@ def test_check_val_every_n_epochs_top_k_integration(tmpdir):
12061206
)
12071207
trainer.fit(model)
12081208
assert set(os.listdir(tmpdir)) == {"epoch=1.ckpt", "epoch=3.ckpt"}
1209+
1210+
1211+
def test_model_checkpoint_saveload_ckpt(tmpdir):
1212+
ckpt = {
1213+
"monitor": "random_value",
1214+
"best_model_path": "epoch=10-step=1436.ckpt",
1215+
"best_model_score": torch.tensor(2.246),
1216+
"current_score": torch.tensor(1.5),
1217+
"dirpath": tmpdir,
1218+
"best_k_models": {"epoch=10-step=1436.ckpt": torch.tensor(2.246)},
1219+
"kth_best_model_path": "epoch=10-step=1436.ckpt",
1220+
"kth_value": torch.tensor(2.246),
1221+
"last_model_path": "last2245.ckpt",
1222+
}
1223+
1224+
# test on_save_checkpoint
1225+
cb_write = ModelCheckpoint(dirpath=tmpdir, monitor="random_value", save_top_k=-1, save_last=True)
1226+
for key, val in ckpt.items():
1227+
setattr(cb_write, key, val)
1228+
written_ckpt = cb_write.on_save_checkpoint("", "", "")
1229+
for state in ckpt:
1230+
assert ckpt[state] == written_ckpt[state]
1231+
1232+
# test on_load_checkpoint
1233+
# Note: "current_score", "dirpath" and "monitor" are currently not restored by on_load_checkpoint.
1234+
# We therefore set "dirpath" and "monitor" to something different than for ckpt/cb_write so we can assert them.
1235+
# "current_score" is left as initialized, i.e. None, and can therefore also be asserted
1236+
cb_restore = ModelCheckpoint(dirpath=tmpdir + "restore", monitor=None, save_top_k=-1, save_last=True)
1237+
cb_restore.on_load_checkpoint("", "", written_ckpt)
1238+
for key, val in written_ckpt.items():
1239+
if key not in ("current_score", "dirpath", "monitor"):
1240+
assert getattr(cb_restore, key) == val
1241+
else:
1242+
assert getattr(cb_restore, key) != val

tests/models/test_restore.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,15 @@ def get_trainer_args():
266266

267267
for before, after in zip(callbacks_before_resume, callback_capture.callbacks):
268268
if isinstance(before, ModelCheckpoint):
269-
assert before.best_model_path == after.best_model_path
270-
assert before.best_model_score == after.best_model_score
269+
for attribute in (
270+
"best_model_path",
271+
"best_model_score",
272+
"best_k_models",
273+
"kth_best_model_path",
274+
"kth_value",
275+
"last_model_path",
276+
):
277+
assert getattr(before, attribute) == getattr(after, attribute)
271278

272279

273280
def test_callbacks_references_fit_ckpt_path(tmpdir):

0 commit comments

Comments (0)