Skip to content

Commit 821bfef

Browse files
committed
fix: revert changes on best_k_models
1 parent 80c4bb6 commit 821bfef

File tree

1 file changed

+10
-15
lines changed

1 file changed

+10
-15
lines changed

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from copy import deepcopy
2828
from datetime import timedelta
2929
from pathlib import Path
30-
from typing import Any, Literal, Optional, Union, cast
30+
from typing import Any, Literal, Optional, Union
3131
from weakref import proxy
3232

3333
import torch
@@ -249,7 +249,7 @@ def __init__(
249249
self._last_global_step_saved = 0 # no need to save when no steps were taken
250250
self._last_time_checked: Optional[float] = None
251251
self.current_score: Optional[Tensor] = None
252-
self.best_k_models: dict[str, dict[str, Tensor | dict[str, Tensor]]] = {}
252+
self.best_k_models: dict[str, Tensor] = {}
253253
self.kth_best_model_path = ""
254254
self.best_model_score: Optional[Tensor] = None
255255
self.best_model_metrics: Optional[dict[str, Tensor]] = None
@@ -372,8 +372,8 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
372372
else:
373373
warnings.warn(
374374
f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r},"
375-
" therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and"
376-
" `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded."
375+
" therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path`,"
376+
" `best_k_models` and `best_model_metrics` won't be reloaded. Only `best_model_path` will be reloaded."
377377
)
378378

379379
self.best_model_path = state_dict["best_model_path"]
@@ -534,9 +534,7 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =
534534
return True
535535

536536
monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
537-
should_update_best_and_save = monitor_op(
538-
current, cast(Tensor, self.best_k_models[self.kth_best_model_path]["score"])
539-
)
537+
should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
540538

541539
# If using multiple devices, make sure all processes are unanimous on the decision.
542540
should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))
@@ -748,22 +746,19 @@ def _update_best_and_save(
748746

749747
# save the current score
750748
self.current_score = current
751-
self.best_k_models[filepath] = {
752-
"score": current,
753-
"metrics": monitor_candidates,
754-
}
749+
self.best_k_models[filepath] = current
755750

756751
if len(self.best_k_models) == k:
757752
# monitor dict has reached k elements
758753
_op = max if self.mode == "min" else min
759754
self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type]
760755
self.kth_value = self.best_k_models[self.kth_best_model_path]
761-
self.kth_model_metrics = self.best_k_models[self.kth_best_model_path]["metrics"]
762756

763757
_op = min if self.mode == "min" else max
764758
self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type]
765-
self.best_model_score = self.best_k_models[self.best_model_path]["score"]
766-
self.best_model_metrics = self.best_k_models[self.best_model_path]["metrics"]
759+
self.best_model_score = self.best_k_models[self.best_model_path]
760+
if self.best_model_path == filepath:
761+
self.best_model_metrics = monitor_candidates
767762

768763
if self.verbose:
769764
epoch = monitor_candidates["epoch"]
@@ -780,7 +775,7 @@ def _update_best_and_save(
780775
def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
781776
"""Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
782777
file."""
783-
best_k = {k: v["score"].item() for k, v in self.best_k_models.items()} # type: ignore[arg-type]
778+
best_k = {k: v.item() for k, v in self.best_k_models.items()}
784779
if filepath is None:
785780
assert self.dirpath
786781
filepath = os.path.join(self.dirpath, "best_k_models.yaml")

0 commit comments

Comments
 (0)