Add `best_k_metrics` parameter to the `ModelCheckpoint` (#20457)
base: master
Changes from 1 commit: 8df37c2
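For context, a minimal usage sketch of what this PR enables. The `Trainer`/`ModelCheckpoint` API is standard Lightning; `best_model_metrics` is the attribute this PR adds, and the printed values are illustrative:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# Keep the single best checkpoint, ranked by validation loss.
mc = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
trainer = Trainer(max_epochs=5, callbacks=[mc])
# trainer.fit(model, datamodule=dm)  # model/dm are user-defined

# Before this PR, only the monitored score of the best checkpoint survives:
print(mc.best_model_score)    # e.g. tensor(0.1234)
# With this PR, all metrics logged when the best checkpoint was saved do too:
print(mc.best_model_metrics)  # e.g. {"val_loss": tensor(0.1234), "epoch": tensor(4)}
```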
File: `model_checkpoint.py`

```diff
@@ -27,21 +27,29 @@
 from copy import deepcopy
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Dict, Literal, Optional, Set, Union
+from typing import Any, Dict, Literal, Optional, Set, Union, cast
 from weakref import proxy

 import torch
 import yaml
 from torch import Tensor
 from typing_extensions import override

-import lightning.pytorch as pl
-from lightning.fabric.utilities.cloud_io import _is_dir, _is_local_file_protocol, get_filesystem
-from lightning.fabric.utilities.types import _PATH
-from lightning.pytorch.callbacks import Checkpoint
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_info, rank_zero_warn
-from lightning.pytorch.utilities.types import STEP_OUTPUT
+import pytorch_lightning as pl
+from lightning_fabric.utilities.cloud_io import (
+    _is_dir,
+    _is_local_file_protocol,
+    get_filesystem,
+)
+from lightning_fabric.utilities.types import _PATH
+from pytorch_lightning.callbacks import Checkpoint
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.rank_zero import (
+    WarningCache,
+    rank_zero_info,
+    rank_zero_warn,
+)
+from pytorch_lightning.utilities.types import STEP_OUTPUT

 log = logging.getLogger(__name__)
 warning_cache = WarningCache()
```

(Two review threads on these import changes were marked as resolved.)
```diff
@@ -241,9 +249,10 @@ def __init__(
         self._last_global_step_saved = 0  # no need to save when no steps were taken
         self._last_time_checked: Optional[float] = None
         self.current_score: Optional[Tensor] = None
-        self.best_k_models: Dict[str, Tensor] = {}
+        self.best_k_models: Dict[str, Dict[str, Tensor | Dict[str, Tensor]]] = {}
         self.kth_best_model_path = ""
         self.best_model_score: Optional[Tensor] = None
+        self.best_model_metrics: Optional[Dict[str, Tensor]] = None
         self.best_model_path = ""
         self.last_model_path = ""
         self._last_checkpoint_saved = ""
```
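Each `best_k_models` entry now bundles the monitored score with the full metrics dict that was logged when that checkpoint was written. A sketch of the new shape, with illustrative paths and values:

```python
import torch

# Shape of a best_k_models entry after this change (illustrative):
best_k_models = {
    "epoch=3-step=400.ckpt": {
        "score": torch.tensor(0.21),  # the monitored value, e.g. "val_loss"
        "metrics": {                  # everything logged at save time
            "val_loss": torch.tensor(0.21),
            "val_acc": torch.tensor(0.91),
        },
    },
}
```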
```diff
@@ -339,6 +348,7 @@ def state_dict(self) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,
             "best_model_score": self.best_model_score,
+            "best_model_metrics": self.best_model_metrics,
             "best_model_path": self.best_model_path,
             "current_score": self.current_score,
             "dirpath": self.dirpath,
```
```diff
@@ -354,6 +364,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         if self.dirpath == dirpath_from_ckpt:
             self.best_model_score = state_dict["best_model_score"]
+            self.best_model_metrics = state_dict["best_model_metrics"]
             self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path)
             self.kth_value = state_dict.get("kth_value", self.kth_value)
             self.best_k_models = state_dict.get("best_k_models", self.best_k_models)
```
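A quick round-trip sketch of the new field through the callback state (assumes the `state_dict`/`load_state_dict` changes above; note the new key is read by direct indexing, so a state dict written before this change would raise `KeyError`, unlike the neighboring `.get` lookups):

```python
from lightning.pytorch.callbacks import ModelCheckpoint

cb = ModelCheckpoint(dirpath="checkpoints", monitor="val_loss")
state = cb.state_dict()          # now also carries "best_model_metrics"

restored = ModelCheckpoint(dirpath="checkpoints", monitor="val_loss")
restored.load_state_dict(state)  # dirpath matches, so the fields are restored
assert restored.best_model_metrics == cb.best_model_metrics
```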
```diff
@@ -523,7 +534,9 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =
             return True

         monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
-        should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
+        should_update_best_and_save = monitor_op(
+            current, cast(Tensor, self.best_k_models[self.kth_best_model_path]["score"])
+        )

         # If using multiple devices, make sure all processes are unanimous on the decision.
         should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))
```
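To make the comparison concrete, a small standalone sketch of the min/max dispatch (illustrative values):

```python
import torch

mode = "min"  # lower is better, as for a loss
monitor_op = {"min": torch.lt, "max": torch.gt}[mode]

current = torch.tensor(0.18)    # newly logged monitored value
kth_score = torch.tensor(0.21)  # score of the worst checkpoint in the current top-k
print(monitor_op(current, kth_score))  # tensor(True) -> replace the kth checkpoint
```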
```diff
@@ -735,17 +748,22 @@ def _update_best_and_save(
         # save the current score
         self.current_score = current
-        self.best_k_models[filepath] = current
+        self.best_k_models[filepath] = {
+            "score": current,
+            "metrics": monitor_candidates,
+        }

         if len(self.best_k_models) == k:
             # monitor dict has reached k elements
             _op = max if self.mode == "min" else min
             self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
             self.kth_value = self.best_k_models[self.kth_best_model_path]
+            self.kth_model_metrics = self.best_k_models[self.kth_best_model_path]["metrics"]

         _op = min if self.mode == "min" else max
         self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
-        self.best_model_score = self.best_k_models[self.best_model_path]
+        self.best_model_score = self.best_k_models[self.best_model_path]["score"]
+        self.best_model_metrics = self.best_k_models[self.best_model_path]["metrics"]

         if self.verbose:
             epoch = monitor_candidates["epoch"]
```
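One thing worth noting: with entries now being dicts, `key=self.best_k_models.get` would compare the nested dicts rather than tensor scores, and dict comparison raises `TypeError` in Python 3. A sketch of ranking on the nested score instead (illustrative, not part of this diff):

```python
import torch

best_k_models = {
    "a.ckpt": {"score": torch.tensor(0.30), "metrics": {"val_loss": torch.tensor(0.30)}},
    "b.ckpt": {"score": torch.tensor(0.21), "metrics": {"val_loss": torch.tensor(0.21)}},
}

# For mode="min", the "kth best" is the largest (worst) retained score:
kth_best = max(best_k_models, key=lambda p: best_k_models[p]["score"])
best = min(best_k_models, key=lambda p: best_k_models[p]["score"])
print(kth_best, best)  # a.ckpt b.ckpt
```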
```diff
@@ -762,7 +780,7 @@ def _update_best_and_save(
     def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
         """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
         file."""
-        best_k = {k: v.item() for k, v in self.best_k_models.items()}
+        best_k = {k: v["score"].item() for k, v in self.best_k_models.items()}  # type: ignore[arg-type]
         if filepath is None:
             assert self.dirpath
             filepath = os.path.join(self.dirpath, "best_k_models.yaml")
```
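The YAML file keeps its original shape, mapping each checkpoint path to its monitored score; for example (illustrative contents):

```yaml
# best_k_models.yaml (illustrative)
epoch=3-step=400.ckpt: 0.21
epoch=4-step=500.ckpt: 0.18
```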
File: `test_model_checkpoint.py`

```diff
@@ -26,11 +26,14 @@
 from unittest.mock import Mock, call, patch

 import cloudpickle
+import lightning.pytorch as pl
 import pytest
 import torch
 import yaml
 from jsonargparse import ArgumentParser
+from tests_pytorch.helpers.runif import RunIf
+from torch import optim

-import lightning.pytorch as pl
 from lightning.fabric.utilities.cloud_io import _load as pl_load
 from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
 from lightning.pytorch import Trainer, seed_everything
@@ -39,9 +42,6 @@
 from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
-from torch import optim
-
-from tests_pytorch.helpers.runif import RunIf

 if _OMEGACONF_AVAILABLE:
     from omegaconf import Container, OmegaConf
```
```diff
@@ -706,6 +706,7 @@ def test_model_checkpoint_save_last_none_monitor(tmp_path, caplog):
     assert checkpoint_callback.best_model_path == str(tmp_path / "epoch=1-step=20.ckpt")
     assert checkpoint_callback.last_model_path == str(tmp_path / "last.ckpt")
     assert checkpoint_callback.best_model_score is None
+    assert checkpoint_callback.best_model_metrics is None
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ""
```

Reviewer comment: we need to add tests that exercise the new code.

Reviewer comment: @gonzachiar, you only test the scenario where the newly added attribute is empty; we need a test that validates your addition works as expected.
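A minimal sketch of the kind of test the reviewers are asking for (hypothetical, not part of this diff; it assumes the file's existing imports plus `BoringModel` from `lightning.pytorch.demos.boring_classes`, and a model that logs the monitored metric):

```python
def test_best_model_metrics_populated(tmp_path):
    """Hypothetical test: best_model_metrics should hold the metrics that were
    logged when the best checkpoint was saved."""

    class LoggingModel(BoringModel):
        def training_step(self, batch, batch_idx):
            out = super().training_step(batch, batch_idx)
            self.log("train_loss", out["loss"])
            return out

    mc = ModelCheckpoint(dirpath=tmp_path, monitor="train_loss", save_top_k=1)
    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=0,
        callbacks=[mc],
        logger=False,
    )
    trainer.fit(LoggingModel())

    assert mc.best_model_metrics is not None
    assert "train_loss" in mc.best_model_metrics
    assert mc.best_model_metrics["train_loss"] == mc.best_model_score
```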
```diff
@@ -812,6 +813,7 @@ def test_model_checkpoint_topk_zero(tmp_path):
     assert checkpoint_callback.monitor is None
     assert checkpoint_callback.best_model_path == ""
     assert checkpoint_callback.best_model_score is None
+    assert checkpoint_callback.best_model_metrics is None
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ""
     # check that only the last ckpt was created
```
```diff
@@ -891,11 +893,13 @@ def test_default_checkpoint_behavior(tmp_path):
     assert len(results) == 1
     save_dir = tmp_path / "checkpoints"
     save_weights_only = trainer.checkpoint_callback.save_weights_only
-    save_mock.assert_has_calls([
-        call(str(save_dir / "epoch=0-step=5.ckpt"), save_weights_only),
-        call(str(save_dir / "epoch=1-step=10.ckpt"), save_weights_only),
-        call(str(save_dir / "epoch=2-step=15.ckpt"), save_weights_only),
-    ])
+    save_mock.assert_has_calls(
+        [
+            call(str(save_dir / "epoch=0-step=5.ckpt"), save_weights_only),
+            call(str(save_dir / "epoch=1-step=10.ckpt"), save_weights_only),
+            call(str(save_dir / "epoch=2-step=15.ckpt"), save_weights_only),
+        ]
+    )
     ckpts = os.listdir(save_dir)
     assert len(ckpts) == 1
     assert ckpts[0] == "epoch=2-step=15.ckpt"
```
```diff
@@ -1479,6 +1483,8 @@ def test_save_last_versioning(tmp_path):
     assert all(not os.path.islink(tmp_path / path) for path in set(os.listdir(tmp_path)))


+def test_none_monitor_saves_correct_best_model_path(tmp_path):
+    mc = ModelCheckpoint(dirpath=tmp_path, monitor=None)
+    trainer = Trainer(callbacks=mc)
```