44 changes: 31 additions & 13 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -27,21 +27,29 @@
from copy import deepcopy
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Set, Union
from typing import Any, Dict, Literal, Optional, Set, Union, cast
from weakref import proxy

import torch
import yaml
from torch import Tensor
from typing_extensions import override

import lightning.pytorch as pl
from lightning.fabric.utilities.cloud_io import _is_dir, _is_local_file_protocol, get_filesystem
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks import Checkpoint
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.types import STEP_OUTPUT
import pytorch_lightning as pl
from lightning_fabric.utilities.cloud_io import (
_is_dir,
_is_local_file_protocol,
get_filesystem,
)
from lightning_fabric.utilities.types import _PATH
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import (
WarningCache,
rank_zero_info,
rank_zero_warn,
)
from pytorch_lightning.utilities.types import STEP_OUTPUT

log = logging.getLogger(__name__)
warning_cache = WarningCache()
@@ -241,9 +249,10 @@ def __init__(
self._last_global_step_saved = 0 # no need to save when no steps were taken
self._last_time_checked: Optional[float] = None
self.current_score: Optional[Tensor] = None
self.best_k_models: Dict[str, Tensor] = {}
self.best_k_models: Dict[str, Dict[str, Tensor | Dict[str, Tensor]]] = {}
Collaborator:

this may be easier to read
Dict[str, Dict[str, Tensor]] | Dict[str, Dict[str, Dict[str, Tensor]]]
but ultimately we'd be better off defining a type alias

more importantly, we need to avoid breaking backward compatibility here
so whatever code relies on best_k_models being Dict[str, Tensor] today needs to keep working

I suggest we just limit ourselves to tracking best_model_metrics and not mess with best_k_models, or use a separate private attribute
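
For reference, a minimal sketch of that suggestion, kept outside the real `ModelCheckpoint` class; `_MONITOR_METRICS`, `_CheckpointState`, `record`, and `_best_k_metrics` are hypothetical names, not existing APIs:

```python
from typing import Dict, Optional

from torch import Tensor

# hypothetical type alias for one snapshot of monitored metrics
_MONITOR_METRICS = Dict[str, Tensor]


class _CheckpointState:
    """Hypothetical stand-in for the relevant ModelCheckpoint attributes."""

    def __init__(self) -> None:
        # unchanged layout, so existing callers of best_k_models keep working
        self.best_k_models: Dict[str, Tensor] = {}
        # new public attribute holding the metrics of the best checkpoint
        self.best_model_metrics: Optional[_MONITOR_METRICS] = None
        # new private mapping from checkpoint path to its metrics snapshot
        self._best_k_metrics: Dict[str, _MONITOR_METRICS] = {}

    def record(self, filepath: str, score: Tensor, metrics: _MONITOR_METRICS) -> None:
        self.best_k_models[filepath] = score
        self._best_k_metrics[filepath] = metrics

    def promote_best(self, best_path: str) -> None:
        self.best_model_metrics = self._best_k_metrics.get(best_path)
```

This keeps the public `best_k_models` contract intact while still exposing the per-checkpoint metrics through `best_model_metrics`.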

self.kth_best_model_path = ""
self.best_model_score: Optional[Tensor] = None
self.best_model_metrics: Optional[Dict[str, Tensor]] = None
self.best_model_path = ""
self.last_model_path = ""
self._last_checkpoint_saved = ""
@@ -339,6 +348,7 @@ def state_dict(self) -> Dict[str, Any]:
return {
"monitor": self.monitor,
"best_model_score": self.best_model_score,
"best_model_metrics": self.best_model_metrics,
"best_model_path": self.best_model_path,
"current_score": self.current_score,
"dirpath": self.dirpath,
@@ -354,6 +364,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:

if self.dirpath == dirpath_from_ckpt:
self.best_model_score = state_dict["best_model_score"]
self.best_model_metrics = state_dict["best_model_metrics"]
self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path)
self.kth_value = state_dict.get("kth_value", self.kth_value)
self.best_k_models = state_dict.get("best_k_models", self.best_k_models)
@@ -523,7 +534,9 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =
return True

monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
should_update_best_and_save = monitor_op(
current, cast(Tensor, self.best_k_models[self.kth_best_model_path]["score"])
Collaborator:
this will stay as in the original if we avoid changing best_k_models

)

# If using multiple devices, make sure all processes are unanimous on the decision.
should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))
@@ -735,17 +748,22 @@ def _update_best_and_save(

# save the current score
self.current_score = current
self.best_k_models[filepath] = current
self.best_k_models[filepath] = {
"score": current,
"metrics": monitor_candidates,
}

if len(self.best_k_models) == k:
# monitor dict has reached k elements
_op = max if self.mode == "min" else min
self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type]
self.kth_value = self.best_k_models[self.kth_best_model_path]
self.kth_model_metrics = self.best_k_models[self.kth_best_model_path]["metrics"]

_op = min if self.mode == "min" else max
self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type]
self.best_model_score = self.best_k_models[self.best_model_path]
self.best_model_score = self.best_k_models[self.best_model_path]["score"]
self.best_model_metrics = self.best_k_models[self.best_model_path]["metrics"]

if self.verbose:
epoch = monitor_candidates["epoch"]
@@ -762,7 +780,7 @@ def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
"""Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
file."""
best_k = {k: v.item() for k, v in self.best_k_models.items()}
best_k = {k: v["score"].item() for k, v in self.best_k_models.items()} # type: ignore[arg-type]
if filepath is None:
assert self.dirpath
filepath = os.path.join(self.dirpath, "best_k_models.yaml")
24 changes: 15 additions & 9 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -26,11 +26,14 @@
from unittest.mock import Mock, call, patch

import cloudpickle
import lightning.pytorch as pl
import pytest
import torch
import yaml
from jsonargparse import ArgumentParser
from tests_pytorch.helpers.runif import RunIf
from torch import optim

import lightning.pytorch as pl
from lightning.fabric.utilities.cloud_io import _load as pl_load
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from lightning.pytorch import Trainer, seed_everything
@@ -39,9 +42,6 @@
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
from torch import optim

from tests_pytorch.helpers.runif import RunIf

if _OMEGACONF_AVAILABLE:
from omegaconf import Container, OmegaConf
@@ -706,6 +706,7 @@ def test_model_checkpoint_save_last_none_monitor(tmp_path, caplog):
assert checkpoint_callback.best_model_path == str(tmp_path / "epoch=1-step=20.ckpt")
assert checkpoint_callback.last_model_path == str(tmp_path / "last.ckpt")
assert checkpoint_callback.best_model_score is None
assert checkpoint_callback.best_model_metrics is None
Collaborator:
we need to add tests that exercise the new code

Collaborator:
@gonzachiar, you have only tested the scenario where the newly added attribute is empty, but we need to add a test to validate that your addition works as expected
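
One possible shape for such a test, sketched under the assumption that `BoringModel` from `lightning.pytorch.demos.boring_classes` is available to this test suite; `MetricsModel` and the test name are illustrative, not part of the PR:

```python
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel


class MetricsModel(BoringModel):
    def validation_step(self, batch, batch_idx):
        # log a metric for ModelCheckpoint to monitor
        loss = self.step(batch)
        self.log("val_loss", loss)
        return loss


def test_best_model_metrics_tracks_monitored_value(tmp_path):
    mc = ModelCheckpoint(dirpath=tmp_path, monitor="val_loss", save_top_k=1)
    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[mc],
        logger=False,
        enable_progress_bar=False,
    )
    trainer.fit(MetricsModel())

    # the new attribute should be populated and contain the monitored metric
    assert mc.best_model_metrics is not None
    assert "val_loss" in mc.best_model_metrics
    assert torch.equal(mc.best_model_metrics["val_loss"], mc.best_model_score)
```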

assert checkpoint_callback.best_k_models == {}
assert checkpoint_callback.kth_best_model_path == ""

@@ -812,6 +813,7 @@ def test_model_checkpoint_topk_zero(tmp_path):
assert checkpoint_callback.monitor is None
assert checkpoint_callback.best_model_path == ""
assert checkpoint_callback.best_model_score is None
assert checkpoint_callback.best_model_metrics is None
assert checkpoint_callback.best_k_models == {}
assert checkpoint_callback.kth_best_model_path == ""
# check that only the last ckpt was created
@@ -891,11 +893,13 @@ def test_default_checkpoint_behavior(tmp_path):
assert len(results) == 1
save_dir = tmp_path / "checkpoints"
save_weights_only = trainer.checkpoint_callback.save_weights_only
save_mock.assert_has_calls([
call(str(save_dir / "epoch=0-step=5.ckpt"), save_weights_only),
call(str(save_dir / "epoch=1-step=10.ckpt"), save_weights_only),
call(str(save_dir / "epoch=2-step=15.ckpt"), save_weights_only),
])
save_mock.assert_has_calls(
[
call(str(save_dir / "epoch=0-step=5.ckpt"), save_weights_only),
call(str(save_dir / "epoch=1-step=10.ckpt"), save_weights_only),
call(str(save_dir / "epoch=2-step=15.ckpt"), save_weights_only),
]
)
ckpts = os.listdir(save_dir)
assert len(ckpts) == 1
assert ckpts[0] == "epoch=2-step=15.ckpt"
@@ -1479,6 +1483,8 @@ def test_save_last_versioning(tmp_path):
assert all(not os.path.islink(tmp_path / path) for path in set(os.listdir(tmp_path)))




def test_none_monitor_saves_correct_best_model_path(tmp_path):
mc = ModelCheckpoint(dirpath=tmp_path, monitor=None)
trainer = Trainer(callbacks=mc)