Skip to content

Commit 8df37c2

Browse files
committed
add best_k_metrics parameter
1 parent 5be58f6 commit 8df37c2

File tree

2 files changed

+46
-22
lines changed

2 files changed

+46
-22
lines changed

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 31 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -27,21 +27,29 @@
2727
from copy import deepcopy
2828
from datetime import timedelta
2929
from pathlib import Path
30-
from typing import Any, Dict, Literal, Optional, Set, Union
30+
from typing import Any, Dict, Literal, Optional, Set, Union, cast
3131
from weakref import proxy
3232

3333
import torch
3434
import yaml
3535
from torch import Tensor
3636
from typing_extensions import override
3737

38-
import lightning.pytorch as pl
39-
from lightning.fabric.utilities.cloud_io import _is_dir, _is_local_file_protocol, get_filesystem
40-
from lightning.fabric.utilities.types import _PATH
41-
from lightning.pytorch.callbacks import Checkpoint
42-
from lightning.pytorch.utilities.exceptions import MisconfigurationException
43-
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_info, rank_zero_warn
44-
from lightning.pytorch.utilities.types import STEP_OUTPUT
38+
import pytorch_lightning as pl
39+
from lightning_fabric.utilities.cloud_io import (
40+
_is_dir,
41+
_is_local_file_protocol,
42+
get_filesystem,
43+
)
44+
from lightning_fabric.utilities.types import _PATH
45+
from pytorch_lightning.callbacks import Checkpoint
46+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
47+
from pytorch_lightning.utilities.rank_zero import (
48+
WarningCache,
49+
rank_zero_info,
50+
rank_zero_warn,
51+
)
52+
from pytorch_lightning.utilities.types import STEP_OUTPUT
4553

4654
log = logging.getLogger(__name__)
4755
warning_cache = WarningCache()
@@ -241,9 +249,10 @@ def __init__(
241249
self._last_global_step_saved = 0 # no need to save when no steps were taken
242250
self._last_time_checked: Optional[float] = None
243251
self.current_score: Optional[Tensor] = None
244-
self.best_k_models: Dict[str, Tensor] = {}
252+
self.best_k_models: Dict[str, Dict[str, Tensor | Dict[str, Tensor]]] = {}
245253
self.kth_best_model_path = ""
246254
self.best_model_score: Optional[Tensor] = None
255+
self.best_model_metrics: Optional[Dict[str, Tensor]] = None
247256
self.best_model_path = ""
248257
self.last_model_path = ""
249258
self._last_checkpoint_saved = ""
@@ -339,6 +348,7 @@ def state_dict(self) -> Dict[str, Any]:
339348
return {
340349
"monitor": self.monitor,
341350
"best_model_score": self.best_model_score,
351+
"best_model_metrics": self.best_model_metrics,
342352
"best_model_path": self.best_model_path,
343353
"current_score": self.current_score,
344354
"dirpath": self.dirpath,
@@ -354,6 +364,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
354364

355365
if self.dirpath == dirpath_from_ckpt:
356366
self.best_model_score = state_dict["best_model_score"]
367+
self.best_model_metrics = state_dict["best_model_metrics"]
357368
self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path)
358369
self.kth_value = state_dict.get("kth_value", self.kth_value)
359370
self.best_k_models = state_dict.get("best_k_models", self.best_k_models)
@@ -523,7 +534,9 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =
523534
return True
524535

525536
monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
526-
should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
537+
should_update_best_and_save = monitor_op(
538+
current, cast(Tensor, self.best_k_models[self.kth_best_model_path]["score"])
539+
)
527540

528541
# If using multiple devices, make sure all processes are unanimous on the decision.
529542
should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))
@@ -735,17 +748,22 @@ def _update_best_and_save(
735748

736749
# save the current score
737750
self.current_score = current
738-
self.best_k_models[filepath] = current
751+
self.best_k_models[filepath] = {
752+
"score": current,
753+
"metrics": monitor_candidates,
754+
}
739755

740756
if len(self.best_k_models) == k:
741757
# monitor dict has reached k elements
742758
_op = max if self.mode == "min" else min
743759
self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type]
744760
self.kth_value = self.best_k_models[self.kth_best_model_path]
761+
self.kth_model_metrics = self.best_k_models[self.kth_best_model_path]["metrics"]
745762

746763
_op = min if self.mode == "min" else max
747764
self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type]
748-
self.best_model_score = self.best_k_models[self.best_model_path]
765+
self.best_model_score = self.best_k_models[self.best_model_path]["score"]
766+
self.best_model_metrics = self.best_k_models[self.best_model_path]["metrics"]
749767

750768
if self.verbose:
751769
epoch = monitor_candidates["epoch"]
@@ -762,7 +780,7 @@ def _update_best_and_save(
762780
def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
763781
"""Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
764782
file."""
765-
best_k = {k: v.item() for k, v in self.best_k_models.items()}
783+
best_k = {k: v["score"].item() for k, v in self.best_k_models.items()} # type: ignore[arg-type]
766784
if filepath is None:
767785
assert self.dirpath
768786
filepath = os.path.join(self.dirpath, "best_k_models.yaml")

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 15 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -26,11 +26,14 @@
2626
from unittest.mock import Mock, call, patch
2727

2828
import cloudpickle
29-
import lightning.pytorch as pl
3029
import pytest
3130
import torch
3231
import yaml
3332
from jsonargparse import ArgumentParser
33+
from tests_pytorch.helpers.runif import RunIf
34+
from torch import optim
35+
36+
import lightning.pytorch as pl
3437
from lightning.fabric.utilities.cloud_io import _load as pl_load
3538
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
3639
from lightning.pytorch import Trainer, seed_everything
@@ -39,9 +42,6 @@
3942
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
4043
from lightning.pytorch.utilities.exceptions import MisconfigurationException
4144
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
42-
from torch import optim
43-
44-
from tests_pytorch.helpers.runif import RunIf
4545

4646
if _OMEGACONF_AVAILABLE:
4747
from omegaconf import Container, OmegaConf
@@ -706,6 +706,7 @@ def test_model_checkpoint_save_last_none_monitor(tmp_path, caplog):
706706
assert checkpoint_callback.best_model_path == str(tmp_path / "epoch=1-step=20.ckpt")
707707
assert checkpoint_callback.last_model_path == str(tmp_path / "last.ckpt")
708708
assert checkpoint_callback.best_model_score is None
709+
assert checkpoint_callback.best_model_metrics is None
709710
assert checkpoint_callback.best_k_models == {}
710711
assert checkpoint_callback.kth_best_model_path == ""
711712

@@ -812,6 +813,7 @@ def test_model_checkpoint_topk_zero(tmp_path):
812813
assert checkpoint_callback.monitor is None
813814
assert checkpoint_callback.best_model_path == ""
814815
assert checkpoint_callback.best_model_score is None
816+
assert checkpoint_callback.best_model_metrics is None
815817
assert checkpoint_callback.best_k_models == {}
816818
assert checkpoint_callback.kth_best_model_path == ""
817819
# check that only the last ckpt was created
@@ -891,11 +893,13 @@ def test_default_checkpoint_behavior(tmp_path):
891893
assert len(results) == 1
892894
save_dir = tmp_path / "checkpoints"
893895
save_weights_only = trainer.checkpoint_callback.save_weights_only
894-
save_mock.assert_has_calls([
895-
call(str(save_dir / "epoch=0-step=5.ckpt"), save_weights_only),
896-
call(str(save_dir / "epoch=1-step=10.ckpt"), save_weights_only),
897-
call(str(save_dir / "epoch=2-step=15.ckpt"), save_weights_only),
898-
])
896+
save_mock.assert_has_calls(
897+
[
898+
call(str(save_dir / "epoch=0-step=5.ckpt"), save_weights_only),
899+
call(str(save_dir / "epoch=1-step=10.ckpt"), save_weights_only),
900+
call(str(save_dir / "epoch=2-step=15.ckpt"), save_weights_only),
901+
]
902+
)
899903
ckpts = os.listdir(save_dir)
900904
assert len(ckpts) == 1
901905
assert ckpts[0] == "epoch=2-step=15.ckpt"
@@ -1479,6 +1483,8 @@ def test_save_last_versioning(tmp_path):
14791483
assert all(not os.path.islink(tmp_path / path) for path in set(os.listdir(tmp_path)))
14801484

14811485

1486+
1487+
14821488
def test_none_monitor_saves_correct_best_model_path(tmp_path):
14831489
mc = ModelCheckpoint(dirpath=tmp_path, monitor=None)
14841490
trainer = Trainer(callbacks=mc)

0 commit comments

Comments
 (0)