2727from copy import deepcopy
2828from datetime import timedelta
2929from pathlib import Path
30- from typing import Any , Dict , Literal , Optional , Set , Union
30+ from typing import Any , Dict , Literal , Optional , Set , Union , cast
3131from weakref import proxy
3232
3333import torch
3434import yaml
3535from torch import Tensor
3636from typing_extensions import override
3737
38- import lightning .pytorch as pl
39- from lightning .fabric .utilities .cloud_io import _is_dir , _is_local_file_protocol , get_filesystem
40- from lightning .fabric .utilities .types import _PATH
41- from lightning .pytorch .callbacks import Checkpoint
42- from lightning .pytorch .utilities .exceptions import MisconfigurationException
43- from lightning .pytorch .utilities .rank_zero import WarningCache , rank_zero_info , rank_zero_warn
44- from lightning .pytorch .utilities .types import STEP_OUTPUT
38+ import pytorch_lightning as pl
39+ from lightning_fabric .utilities .cloud_io import (
40+ _is_dir ,
41+ _is_local_file_protocol ,
42+ get_filesystem ,
43+ )
44+ from lightning_fabric .utilities .types import _PATH
45+ from pytorch_lightning .callbacks import Checkpoint
46+ from pytorch_lightning .utilities .exceptions import MisconfigurationException
47+ from pytorch_lightning .utilities .rank_zero import (
48+ WarningCache ,
49+ rank_zero_info ,
50+ rank_zero_warn ,
51+ )
52+ from pytorch_lightning .utilities .types import STEP_OUTPUT
4553
4654log = logging .getLogger (__name__ )
4755warning_cache = WarningCache ()
@@ -241,9 +249,10 @@ def __init__(
241249 self ._last_global_step_saved = 0 # no need to save when no steps were taken
242250 self ._last_time_checked : Optional [float ] = None
243251 self .current_score : Optional [Tensor ] = None
244- self .best_k_models : Dict [str , Tensor ] = {}
252+ self.best_k_models: Dict[str, Dict[str, Union[Tensor, Dict[str, Tensor]]]] = {}
245253 self .kth_best_model_path = ""
246254 self .best_model_score : Optional [Tensor ] = None
255+ self.best_model_metrics: Optional[Dict[str, Tensor]] = None
256+ self.kth_model_metrics: Optional[Dict[str, Tensor]] = None
247256 self .best_model_path = ""
248257 self .last_model_path = ""
249258 self ._last_checkpoint_saved = ""
@@ -339,6 +348,7 @@ def state_dict(self) -> Dict[str, Any]:
339348 return {
340349 "monitor" : self .monitor ,
341350 "best_model_score" : self .best_model_score ,
351+ "best_model_metrics" : self .best_model_metrics ,
342352 "best_model_path" : self .best_model_path ,
343353 "current_score" : self .current_score ,
344354 "dirpath" : self .dirpath ,
@@ -354,6 +364,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
354364
355365 if self .dirpath == dirpath_from_ckpt :
356366 self .best_model_score = state_dict ["best_model_score" ]
367+ self.best_model_metrics = state_dict.get("best_model_metrics", self.best_model_metrics)
357368 self .kth_best_model_path = state_dict .get ("kth_best_model_path" , self .kth_best_model_path )
358369 self .kth_value = state_dict .get ("kth_value" , self .kth_value )
359370 self .best_k_models = state_dict .get ("best_k_models" , self .best_k_models )
@@ -523,7 +534,9 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =
523534 return True
524535
525536 monitor_op = {"min" : torch .lt , "max" : torch .gt }[self .mode ]
526- should_update_best_and_save = monitor_op (current , self .best_k_models [self .kth_best_model_path ])
537+ should_update_best_and_save = monitor_op (
538+ current , cast (Tensor , self .best_k_models [self .kth_best_model_path ]["score" ])
539+ )
527540
528541 # If using multiple devices, make sure all processes are unanimous on the decision.
529542 should_update_best_and_save = trainer .strategy .reduce_boolean_decision (bool (should_update_best_and_save ))
@@ -735,17 +748,22 @@ def _update_best_and_save(
735748
736749 # save the current score
737750 self .current_score = current
738- self .best_k_models [filepath ] = current
751+ self .best_k_models [filepath ] = {
752+ "score" : current ,
753+ "metrics" : monitor_candidates ,
754+ }
739755
740756 if len (self .best_k_models ) == k :
741757 # monitor dict has reached k elements
742758 _op = max if self .mode == "min" else min
743- self .kth_best_model_path = _op (self .best_k_models , key = self .best_k_models .get ) # type: ignore[arg-type]
744- self .kth_value = self .best_k_models [self .kth_best_model_path ]
759+ # values are now {"score", "metrics"} dicts, so rank entries by the stored score
760+ self.kth_best_model_path = _op(self.best_k_models, key=lambda p: self.best_k_models[p]["score"])
761+ self.kth_value = cast(Tensor, self.best_k_models[self.kth_best_model_path]["score"])
762+ self.kth_model_metrics = self.best_k_models[self.kth_best_model_path]["metrics"]
745762
746763 _op = min if self .mode == "min" else max
747- self .best_model_path = _op (self .best_k_models , key = self .best_k_models .get ) # type: ignore[arg-type]
764+ self.best_model_path = _op(self.best_k_models, key=lambda p: self.best_k_models[p]["score"])
748- self .best_model_score = self .best_k_models [self .best_model_path ]
765+ self .best_model_score = self .best_k_models [self .best_model_path ]["score" ]
766+ self .best_model_metrics = self .best_k_models [self .best_model_path ]["metrics" ]
749767
750768 if self .verbose :
751769 epoch = monitor_candidates ["epoch" ]
@@ -762,7 +780,7 @@ def _update_best_and_save(
762780 def to_yaml (self , filepath : Optional [_PATH ] = None ) -> None :
763781 """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
764782 file."""
765- best_k = {k : v .item () for k , v in self .best_k_models .items ()}
783+ best_k = {k : v [ "score" ] .item () for k , v in self .best_k_models .items ()} # type: ignore[arg-type]
766784 if filepath is None :
767785 assert self .dirpath
768786 filepath = os .path .join (self .dirpath , "best_k_models.yaml" )
0 commit comments