@@ -1638,7 +1638,7 @@ def training_step(self, *args):
16381638
16391639
16401640def test_best_model_metrics (tmp_path ):
1641- """Ensure ModelCheckpoint correctly tracks best_model_metrics."""
1641+ """Ensure ModelCheckpoint correctly tracks and restores best_model_metrics."""
16421642
16431643 class TestModel (BoringModel ):
16441644 def training_step (self , batch , batch_idx ):
@@ -1654,7 +1654,12 @@ def validation_step(self, batch, batch_idx):
16541654 self .log ("val_metric" , (self .current_epoch + 1 ) / 10 )
16551655 return loss
16561656
1657- checkpoint = ModelCheckpoint (dirpath = tmp_path , save_top_k = 3 , monitor = "val_metric" , mode = "min" )
1657+ checkpoint = ModelCheckpoint (
1658+ dirpath = tmp_path ,
1659+ save_top_k = 3 ,
1660+ monitor = "val_metric" ,
1661+ mode = "min" ,
1662+ )
16581663
16591664 trainer = Trainer (
16601665 default_root_dir = tmp_path ,
@@ -1672,15 +1677,37 @@ def validation_step(self, batch, batch_idx):
16721677 assert hasattr (checkpoint , "best_model_metrics" )
16731678 assert isinstance (checkpoint .best_model_metrics , dict )
16741679 assert "val_metric" in checkpoint .best_model_metrics
1675- assert checkpoint .best_model_metrics ["val_metric" ] == 0.1 # best ( lowest) value
1680+ assert checkpoint .best_model_metrics ["val_metric" ] == 0.1 # lowest value
16761681 assert "val_loss" in checkpoint .best_model_metrics
16771682 assert "train_loss" in checkpoint .best_model_metrics
16781683 assert "train_metric" in checkpoint .best_model_metrics
16791684
1685+ best_ckpt_path = checkpoint .best_model_path
1686+ assert best_ckpt_path
1687+ assert os .path .exists (best_ckpt_path )
1688+
1689+ loaded = torch .load (best_ckpt_path , weights_only = False )
1690+
1691+ callbacks_state = loaded .get ("callbacks" , {})
1692+ assert callbacks_state # ensure not empty
1693+
1694+ ckpt_key = next (
1695+ (k for k in callbacks_state if k .startswith ("ModelCheckpoint" )),
1696+ None ,
1697+ )
1698+
1699+ assert ckpt_key is not None
1700+
1701+ loaded_metrics = callbacks_state [ckpt_key ]["best_model_metrics" ]
1702+
1703+ assert isinstance (loaded_metrics , dict )
1704+ assert loaded_metrics == checkpoint .best_model_metrics
1705+ assert loaded_metrics ["val_metric" ] == 0.1
1706+
16801707
16811708@pytest .mark .parametrize ("mode" , ["min" , "max" ])
16821709def test_best_model_metrics_mode (tmp_path , mode : str ):
1683- """Ensure ModelCheckpoint.best_model_metrics respects the 'mode' parameter."""
1710+ """Ensure ModelCheckpoint.best_model_metrics respects the 'mode' parameter and is restored correctly ."""
16841711
16851712 class TestModel (BoringModel ):
16861713 def validation_step (self , batch , batch_idx ):
@@ -1710,6 +1737,26 @@ def validation_step(self, batch, batch_idx):
17101737 expected_value = 0.1 if mode == "min" else 0.3
17111738 assert checkpoint .best_model_metrics ["val_metric" ] == expected_value
17121739
1740+ # load the checkpoint and verify metrics are restored
1741+ best_ckpt_path = checkpoint .best_model_path
1742+ assert best_ckpt_path
1743+ assert os .path .exists (best_ckpt_path )
1744+
1745+ loaded = torch .load (best_ckpt_path , weights_only = False )
1746+ callbacks_state = loaded .get ("callbacks" , {})
1747+ assert callbacks_state
1748+
1749+ ckpt_key = next (
1750+ (k for k in callbacks_state if k .startswith ("ModelCheckpoint" )),
1751+ None ,
1752+ )
1753+ assert ckpt_key is not None
1754+
1755+ loaded_metrics = callbacks_state [ckpt_key ]["best_model_metrics" ]
1756+
1757+ assert isinstance (loaded_metrics , dict )
1758+ assert loaded_metrics ["val_metric" ] == expected_value
1759+
17131760
17141761@pytest .mark .parametrize ("use_omegaconf" , [False , pytest .param (True , marks = RunIf (omegaconf = True ))])
17151762def test_hparams_type (tmp_path , use_omegaconf ):
0 commit comments