2727from copy import deepcopy
2828from datetime import timedelta
2929from pathlib import Path
30- from typing import Any , Literal , Optional , Union , cast
30+ from typing import Any , Literal , Optional , Union
3131from weakref import proxy
3232
3333import torch
@@ -249,7 +249,7 @@ def __init__(
249249 self ._last_global_step_saved = 0 # no need to save when no steps were taken
250250 self ._last_time_checked : Optional [float ] = None
251251 self .current_score : Optional [Tensor ] = None
252- self .best_k_models : dict [str , dict [ str , Tensor | dict [ str , Tensor ]] ] = {}
252+ self .best_k_models : dict [str , Tensor ] = {}
253253 self .kth_best_model_path = ""
254254 self .best_model_score : Optional [Tensor ] = None
255255 self .best_model_metrics : Optional [dict [str , Tensor ]] = None
@@ -372,8 +372,8 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
372372 else :
373373 warnings .warn (
374374 f"The dirpath has changed from { dirpath_from_ckpt !r} to { self .dirpath !r} ,"
375- " therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and "
376- " `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded."
375+ " therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path`, "
376+ " `best_k_models` and `best_model_metrics` won't be reloaded. Only `best_model_path` will be reloaded."
377377 )
378378
379379 self .best_model_path = state_dict ["best_model_path" ]
@@ -534,9 +534,7 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =
534534 return True
535535
536536 monitor_op = {"min" : torch .lt , "max" : torch .gt }[self .mode ]
537- should_update_best_and_save = monitor_op (
538- current , cast (Tensor , self .best_k_models [self .kth_best_model_path ]["score" ])
539- )
537+ should_update_best_and_save = monitor_op (current , self .best_k_models [self .kth_best_model_path ])
540538
541539 # If using multiple devices, make sure all processes are unanimous on the decision.
542540 should_update_best_and_save = trainer .strategy .reduce_boolean_decision (bool (should_update_best_and_save ))
@@ -748,22 +746,19 @@ def _update_best_and_save(
748746
749747 # save the current score
750748 self .current_score = current
751- self .best_k_models [filepath ] = {
752- "score" : current ,
753- "metrics" : monitor_candidates ,
754- }
749+ self .best_k_models [filepath ] = current
755750
756751 if len (self .best_k_models ) == k :
757752 # monitor dict has reached k elements
758753 _op = max if self .mode == "min" else min
759754 self .kth_best_model_path = _op (self .best_k_models , key = self .best_k_models .get ) # type: ignore[arg-type]
760755 self .kth_value = self .best_k_models [self .kth_best_model_path ]
761- self .kth_model_metrics = self .best_k_models [self .kth_best_model_path ]["metrics" ]
762756
763757 _op = min if self .mode == "min" else max
764758 self .best_model_path = _op (self .best_k_models , key = self .best_k_models .get ) # type: ignore[arg-type]
765- self .best_model_score = self .best_k_models [self .best_model_path ]["score" ]
766- self .best_model_metrics = self .best_k_models [self .best_model_path ]["metrics" ]
759+ self .best_model_score = self .best_k_models [self .best_model_path ]
760+ if self .best_model_path == filepath :
761+ self .best_model_metrics = monitor_candidates
767762
768763 if self .verbose :
769764 epoch = monitor_candidates ["epoch" ]
@@ -780,7 +775,7 @@ def _update_best_and_save(
780775 def to_yaml (self , filepath : Optional [_PATH ] = None ) -> None :
781776 """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
782777 file."""
783- best_k = {k : v [ "score" ] .item () for k , v in self .best_k_models .items ()} # type: ignore[arg-type]
778+ best_k = {k : v .item () for k , v in self .best_k_models .items ()}
784779 if filepath is None :
785780 assert self .dirpath
786781 filepath = os .path .join (self .dirpath , "best_k_models.yaml" )
0 commit comments