
Commit 97121a5

Authored by krshrimali, with carmocca, Borda and rohitgr7
Prevent last checkpoint being deleted after resumed training with changed dirpath (#12225)
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
1 parent: abe795e

File tree: 3 files changed (+43 −5 lines)
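Before the file-by-file diff, here is a minimal sketch of the scenario this commit guards against. The tiny model, the directory names `run_a`/`run_b`, and the Trainer arguments are illustrative assumptions, not code from this commit: a first run writes its checkpoints, including `last.ckpt`, to one directory, and training is then resumed with a `ModelCheckpoint` pointing at a different `dirpath`. Before this fix, the resumed callback reloaded `last_model_path` from the checkpoint and could delete the old `last.ckpt` while managing its own "last" file.

import os

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class TinyModel(LightningModule):
    """Stand-in model; any LightningModule reproduces the scenario."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)


model = TinyModel()

# First run: checkpoints, including last.ckpt, are written to ./run_a
trainer = Trainer(
    max_epochs=1,
    logger=False,
    enable_model_summary=False,
    callbacks=[ModelCheckpoint(dirpath="run_a", save_last=True)],
)
trainer.fit(model)
assert os.path.isfile("run_a/last.ckpt")

# Resumed run with a different dirpath. Before this fix, the callback could
# delete run_a/last.ckpt while managing its own "last" checkpoint; with the fix
# the old file is preserved and a warning lists the state that is not reloaded.
trainer = Trainer(
    max_epochs=2,
    logger=False,
    enable_model_summary=False,
    callbacks=[ModelCheckpoint(dirpath="run_b", save_last=True)],
)
trainer.fit(model, ckpt_path="run_a/last.ckpt")
assert os.path.isfile("run_a/last.ckpt")  # still there with this fix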

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -819,6 +819,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed an issue where `ModelCheckpoint` could delete last checkpoint from the old directory when `dirpath` has changed during resumed training ([#12225](https://github.com/PyTorchLightning/pytorch-lightning/pull/12225))
+
+
 - Fixed an issue where `ModelCheckpoint` could delete older checkpoints when `dirpath` has changed during resumed training ([#12045](https://github.com/PyTorchLightning/pytorch-lightning/pull/12045))
 
 

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 5 additions & 4 deletions
@@ -147,7 +147,7 @@ class ModelCheckpoint(Callback):
     then you should create multiple ``ModelCheckpoint`` callbacks.
 
     If the checkpoint's ``dirpath`` changed from what it was before while resuming the training,
-    only ``last_model_path`` and ``best_model_path`` will be reloaded and a warning will be issued.
+    only ``best_model_path`` will be reloaded and a warning will be issued.
 
     Raises:
         MisconfigurationException:

@@ -337,13 +337,14 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path)
             self.kth_value = state_dict.get("kth_value", self.kth_value)
             self.best_k_models = state_dict.get("best_k_models", self.best_k_models)
+            self.last_model_path = state_dict.get("last_model_path", self.last_model_path)
         else:
             warnings.warn(
                 f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r},"
-                " therefore `best_model_score`, `kth_best_model_path`, `kth_value` and `best_k_models`"
-                " won't be reloaded. Only `last_model_path` and `best_model_path` will be reloaded."
+                " therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and"
+                " `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded."
             )
-        self.last_model_path = state_dict.get("last_model_path", self.last_model_path)
+
         self.best_model_path = state_dict["best_model_path"]
 
     def save_checkpoint(self, trainer: "pl.Trainer") -> None:  # pragma: no-cover
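The observable effect of the `load_state_dict` change can be sketched in isolation as well. The `old_state` dict below is hand-written to mimic what a previous run with `dirpath="run_a"` would have stored; the keys and paths are illustrative. With a changed `dirpath`, only `best_model_path` is taken from the checkpoint, while `last_model_path` keeps its default empty value, so the callback no longer tracks, and therefore can no longer delete, the old `last.ckpt`.

from pytorch_lightning.callbacks import ModelCheckpoint

# Hand-written stand-in for the state a previous run (dirpath="run_a") saved.
old_state = {
    "dirpath": "run_a",
    "best_model_path": "run_a/step=3.ckpt",
    "last_model_path": "run_a/last.ckpt",
}

cb = ModelCheckpoint(dirpath="run_b", save_last=True)
cb.load_state_dict(old_state)  # warns that dirpath changed and most state is skipped

print(cb.best_model_path)  # "run_a/step=3.ckpt" -- still reloaded
print(cb.last_model_path)  # ""                  -- not reloaded, run_a/last.ckpt stays untouched

This mirrors the updated expectation in the test below, where `last_model_path` joins the keys that must not match after reloading into a callback with a different `dirpath`.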

tests/checkpointing/test_model_checkpoint.py

Lines changed: 35 additions & 1 deletion
@@ -1192,8 +1192,8 @@ def make_assertions(cb_restore, written_ckpt):
             "kth_best_model_path": False,
             "kth_value": False,
             "best_k_models": False,
+            "last_model_path": False,
             "best_model_path": True,
-            "last_model_path": True,
         }
         for key, should_match in expected_keys.items():
             if should_match:

@@ -1245,6 +1245,40 @@ def on_load_checkpoint(self, *args, **kwargs):
     make_assertions(cb_restore, written_ckpt)
 
 
+def test_resume_training_preserves_old_ckpt_last(tmpdir):
+    """Ensures that the last saved checkpoint is not deleted from the previous folder when training is resumed from
+    the old checkpoint."""
+    model = BoringModel()
+    trainer_kwargs = {
+        "default_root_dir": tmpdir,
+        "max_epochs": 1,
+        "limit_train_batches": 3,
+        "limit_val_batches": 0,
+        "enable_model_summary": False,
+        "logger": False,
+    }
+    mc_kwargs = {
+        "filename": "{step}",
+        "monitor": "step",
+        "mode": "max",
+        "save_last": True,
+        "save_top_k": 2,
+        "every_n_train_steps": 1,
+    }
+    trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
+    trainer.fit(model)
+    # Make sure that the last checkpoint file exists in the dirpath passed (`tmpdir`)
+    assert set(os.listdir(tmpdir / "checkpoints")) == {"last.ckpt", "step=2.ckpt", "step=3.ckpt"}
+
+    # Training it for 2 epochs for extra surety, that nothing gets deleted after multiple epochs
+    trainer_kwargs["max_epochs"] += 1
+    mc_kwargs["dirpath"] = f"{tmpdir}/new"
+    trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
+    trainer.fit(model, ckpt_path=f"{tmpdir}/checkpoints/step=2.ckpt")
+    # Ensure that the file is not deleted from the old folder
+    assert os.path.isfile(f"{tmpdir}/checkpoints/last.ckpt")
+
+
 def test_save_last_saves_correct_last_model_path(tmpdir):
     mc = ModelCheckpoint(dirpath=tmpdir, save_last=True)
     mc.CHECKPOINT_NAME_LAST = "{foo}-last"
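To run just the added test locally, an invocation along these lines should work; the exact command depends on your checkout, and this sketch uses pytest's Python entry point rather than the command line.

import pytest

# Equivalent to: pytest tests/checkpointing/test_model_checkpoint.py -k test_resume_training_preserves_old_ckpt_last
pytest.main(["tests/checkpointing/test_model_checkpoint.py", "-k", "test_resume_training_preserves_old_ckpt_last"])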
