Skip to content

Commit a1e0681

Browse files
committed
Add support for saving and restoring best_model_metrics in ModelCheckpoint
1 parent: 53a1234 · commit: a1e0681

File tree

3 files changed: +55 additions, −5 deletions

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1010

1111
### Added
1212

13-
-
13+
- Added `best_model_metrics` attribute to `ModelCheckpoint` callback to store all logged metrics associated with the best model checkpoint ([#21355](https://github.com/Lightning-AI/pytorch-lightning/pull/21355))
1414

1515
### Changed
1616

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,10 +551,13 @@ def state_dict(self) -> dict[str, Any]:
551551
"kth_best_model_path": self.kth_best_model_path,
552552
"kth_value": self.kth_value,
553553
"last_model_path": self.last_model_path,
554+
"best_model_metrics": self.best_model_metrics,
554555
}
555556

556557
@override
557558
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
559+
self.best_model_metrics = state_dict.get("best_model_metrics", {})
560+
558561
dirpath_from_ckpt = state_dict.get("dirpath", self.dirpath)
559562

560563
if self.dirpath == dirpath_from_ckpt:

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1638,7 +1638,7 @@ def training_step(self, *args):
16381638

16391639

16401640
def test_best_model_metrics(tmp_path):
1641-
"""Ensure ModelCheckpoint correctly tracks best_model_metrics."""
1641+
"""Ensure ModelCheckpoint correctly tracks and restores best_model_metrics."""
16421642

16431643
class TestModel(BoringModel):
16441644
def training_step(self, batch, batch_idx):
@@ -1654,7 +1654,12 @@ def validation_step(self, batch, batch_idx):
16541654
self.log("val_metric", (self.current_epoch + 1) / 10)
16551655
return loss
16561656

1657-
checkpoint = ModelCheckpoint(dirpath=tmp_path, save_top_k=3, monitor="val_metric", mode="min")
1657+
checkpoint = ModelCheckpoint(
1658+
dirpath=tmp_path,
1659+
save_top_k=3,
1660+
monitor="val_metric",
1661+
mode="min",
1662+
)
16581663

16591664
trainer = Trainer(
16601665
default_root_dir=tmp_path,
@@ -1672,15 +1677,37 @@ def validation_step(self, batch, batch_idx):
16721677
assert hasattr(checkpoint, "best_model_metrics")
16731678
assert isinstance(checkpoint.best_model_metrics, dict)
16741679
assert "val_metric" in checkpoint.best_model_metrics
1675-
assert checkpoint.best_model_metrics["val_metric"] == 0.1 # best (lowest) value
1680+
assert checkpoint.best_model_metrics["val_metric"] == 0.1 # lowest value
16761681
assert "val_loss" in checkpoint.best_model_metrics
16771682
assert "train_loss" in checkpoint.best_model_metrics
16781683
assert "train_metric" in checkpoint.best_model_metrics
16791684

1685+
best_ckpt_path = checkpoint.best_model_path
1686+
assert best_ckpt_path
1687+
assert os.path.exists(best_ckpt_path)
1688+
1689+
loaded = torch.load(best_ckpt_path, weights_only=False)
1690+
1691+
callbacks_state = loaded.get("callbacks", {})
1692+
assert callbacks_state # ensure not empty
1693+
1694+
ckpt_key = next(
1695+
(k for k in callbacks_state if k.startswith("ModelCheckpoint")),
1696+
None,
1697+
)
1698+
1699+
assert ckpt_key is not None
1700+
1701+
loaded_metrics = callbacks_state[ckpt_key]["best_model_metrics"]
1702+
1703+
assert isinstance(loaded_metrics, dict)
1704+
assert loaded_metrics == checkpoint.best_model_metrics
1705+
assert loaded_metrics["val_metric"] == 0.1
1706+
16801707

16811708
@pytest.mark.parametrize("mode", ["min", "max"])
16821709
def test_best_model_metrics_mode(tmp_path, mode: str):
1683-
"""Ensure ModelCheckpoint.best_model_metrics respects the 'mode' parameter."""
1710+
"""Ensure ModelCheckpoint.best_model_metrics respects the 'mode' parameter and is restored correctly."""
16841711

16851712
class TestModel(BoringModel):
16861713
def validation_step(self, batch, batch_idx):
@@ -1710,6 +1737,26 @@ def validation_step(self, batch, batch_idx):
17101737
expected_value = 0.1 if mode == "min" else 0.3
17111738
assert checkpoint.best_model_metrics["val_metric"] == expected_value
17121739

1740+
# load the checkpoint and verify metrics are restored
1741+
best_ckpt_path = checkpoint.best_model_path
1742+
assert best_ckpt_path
1743+
assert os.path.exists(best_ckpt_path)
1744+
1745+
loaded = torch.load(best_ckpt_path, weights_only=False)
1746+
callbacks_state = loaded.get("callbacks", {})
1747+
assert callbacks_state
1748+
1749+
ckpt_key = next(
1750+
(k for k in callbacks_state if k.startswith("ModelCheckpoint")),
1751+
None,
1752+
)
1753+
assert ckpt_key is not None
1754+
1755+
loaded_metrics = callbacks_state[ckpt_key]["best_model_metrics"]
1756+
1757+
assert isinstance(loaded_metrics, dict)
1758+
assert loaded_metrics["val_metric"] == expected_value
1759+
17131760

17141761
@pytest.mark.parametrize("use_omegaconf", [False, pytest.param(True, marks=RunIf(omegaconf=True))])
17151762
def test_hparams_type(tmp_path, use_omegaconf):

0 commit comments

Comments (0)