Skip to content

Commit 482da0a

Browse files
authored
Fix ModelCheckpoint alternating between versioned and unversioned file (#19064)
1 parent e30401a commit 482da0a

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
5959
- Fixed handling checkpoint dirpath suffix in NeptuneLogger ([#18863](https://github.com/Lightning-AI/lightning/pull/18863))
6060

6161

62+
- Fixed an edge case where `ModelCheckpoint` would alternate between versioned and unversioned filename ([#19064](https://github.com/Lightning-AI/lightning/pull/19064))
63+
6264

6365
## [2.1.2] - 2023-11-15
6466

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,7 @@ def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Di
703703
rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")
704704

705705
def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
706-
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
706+
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, self.best_model_path)
707707
# set the best model path before saving because it will be part of the state.
708708
previous, self.best_model_path = self.best_model_path, filepath
709709
self._save_checkpoint(trainer, filepath)
@@ -773,7 +773,7 @@ def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, curren
773773
"""Checks if the previous checkpoint should be deleted.
774774
775775
A checkpoint won't be deleted if any of the cases apply:
776-
- The previous checkpoint is the same as the current checkpoint
776+
- The previous checkpoint is the same as the current checkpoint (means the old was already overwritten by new)
777777
- The previous checkpoint is not in the current checkpoint directory and the filesystem is local
778778
- The previous checkpoint is the checkpoint the Trainer resumed from and the filesystem is local
779779

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,28 @@ def test_none_monitor_top_k(tmpdir):
583583
ModelCheckpoint(dirpath=tmpdir, save_top_k=1)
584584

585585

586+
def test_none_monitor_not_alternating(tmp_path):
587+
"""Regression test for the case where the callback saved alternating `model.ckpt` and `model-v1.ckpt` files."""
588+
589+
class ListDirModel(BoringModel):
590+
def on_train_epoch_start(self):
591+
if self.current_epoch > 0:
592+
assert os.listdir(tmp_path) == ["model.ckpt"]
593+
594+
model = ListDirModel()
595+
model_checkpoint = ModelCheckpoint(dirpath=tmp_path, monitor=None, save_top_k=1, filename="model")
596+
trainer = Trainer(
597+
callbacks=model_checkpoint,
598+
limit_train_batches=1,
599+
limit_val_batches=0,
600+
max_epochs=3,
601+
enable_model_summary=False,
602+
enable_progress_bar=False,
603+
logger=False,
604+
)
605+
trainer.fit(model)
606+
607+
586608
def test_invalid_every_n_epochs(tmpdir):
587609
"""Make sure that a MisconfigurationException is raised for a negative every_n_epochs argument."""
588610
with pytest.raises(MisconfigurationException, match=r".*Must be >= 0"):

0 commit comments

Comments (0)