Skip to content
23 changes: 18 additions & 5 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,20 @@
from typing_extensions import override

import lightning.pytorch as pl
from lightning.fabric.utilities.cloud_io import _is_dir, _is_local_file_protocol, get_filesystem
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks import Checkpoint
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.rank_zero import (
WarningCache,
rank_zero_info,
rank_zero_warn,
)
from lightning.pytorch.utilities.types import STEP_OUTPUT
from lightning_fabric.utilities.cloud_io import (
_is_dir,
_is_local_file_protocol,
get_filesystem,
)
from lightning_fabric.utilities.types import _PATH

log = logging.getLogger(__name__)
warning_cache = WarningCache()
Expand Down Expand Up @@ -254,6 +262,7 @@ def __init__(
self.best_k_models: dict[str, Tensor] = {}
self.kth_best_model_path = ""
self.best_model_score: Optional[Tensor] = None
self.best_model_metrics: Optional[dict[str, Tensor]] = None
self.best_model_path = ""
self.last_model_path = ""
self._last_checkpoint_saved = ""
Expand Down Expand Up @@ -349,6 +358,7 @@ def state_dict(self) -> dict[str, Any]:
return {
"monitor": self.monitor,
"best_model_score": self.best_model_score,
"best_model_metrics": self.best_model_metrics,
"best_model_path": self.best_model_path,
"current_score": self.current_score,
"dirpath": self.dirpath,
Expand All @@ -364,15 +374,16 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:

if self.dirpath == dirpath_from_ckpt:
self.best_model_score = state_dict["best_model_score"]
self.best_model_metrics = state_dict["best_model_metrics"]
self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path)
self.kth_value = state_dict.get("kth_value", self.kth_value)
self.best_k_models = state_dict.get("best_k_models", self.best_k_models)
self.last_model_path = state_dict.get("last_model_path", self.last_model_path)
else:
warnings.warn(
f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r},"
" therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and"
" `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded."
" therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path`,"
" `best_k_models` and `best_model_metrics` won't be reloaded. Only `best_model_path` will be reloaded."
)

self.best_model_path = state_dict["best_model_path"]
Expand Down Expand Up @@ -756,6 +767,8 @@ def _update_best_and_save(
_op = min if self.mode == "min" else max
self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type]
self.best_model_score = self.best_k_models[self.best_model_path]
if self.best_model_path == filepath:
self.best_model_metrics = monitor_candidates

if self.verbose:
epoch = monitor_candidates["epoch"]
Expand Down
5 changes: 3 additions & 2 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from torch import optim
from torch.utils.data.dataloader import DataLoader

import lightning.pytorch as pl
from lightning.fabric.utilities.cloud_io import _load as pl_load
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
Expand Down Expand Up @@ -703,6 +702,7 @@ def test_model_checkpoint_save_last_none_monitor(tmp_path, caplog):
assert checkpoint_callback.best_model_path == str(tmp_path / "epoch=1-step=20.ckpt")
assert checkpoint_callback.last_model_path == str(tmp_path / "last.ckpt")
assert checkpoint_callback.best_model_score is None
assert checkpoint_callback.best_model_metrics is None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to add tests that exercise the new code

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gonzachiar, you have only added a test for the scenario in which the newly added attribute is empty; we also need a test that validates that your addition works as expected when the attribute is populated.

assert checkpoint_callback.best_k_models == {}
assert checkpoint_callback.kth_best_model_path == ""

Expand Down Expand Up @@ -809,6 +809,7 @@ def test_model_checkpoint_topk_zero(tmp_path):
assert checkpoint_callback.monitor is None
assert checkpoint_callback.best_model_path == ""
assert checkpoint_callback.best_model_score is None
assert checkpoint_callback.best_model_metrics is None
assert checkpoint_callback.best_k_models == {}
assert checkpoint_callback.kth_best_model_path == ""
# check that only the last ckpt was created
Expand Down Expand Up @@ -1074,7 +1075,7 @@ def assert_checkpoint_log_dir(idx):

# load from checkpoint
trainer_config["logger"] = TensorBoardLogger(tmp_path)
trainer = pl.Trainer(**trainer_config)
trainer = Trainer(**trainer_config)
assert_trainer_init(trainer)

model = ExtendedBoringModel()
Expand Down
Loading