diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 8a5d9dcdf786f..f61b10c30bafb 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -256,6 +256,7 @@ def __init__(
         self._last_global_step_saved = 0  # no need to save when no steps were taken
         self._last_time_checked: Optional[float] = None
         self.current_score: Optional[Tensor] = None
+        self.best_model_metrics: Optional[dict[str, Tensor]] = {}
         self.best_k_models: dict[str, Tensor] = {}
         self.kth_best_model_path = ""
         self.best_model_score: Optional[Tensor] = None
@@ -858,6 +859,9 @@ def _update_best_and_save(
         self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
         self.best_model_score = self.best_k_models[self.best_model_path]
 
+        if self.best_model_path == filepath:
+            self.best_model_metrics = dict(monitor_candidates)
+
         if self.verbose:
             epoch = monitor_candidates["epoch"]
             step = monitor_candidates["step"]
diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index 449484da970a8..07bae39e42bac 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -1637,6 +1637,80 @@ def training_step(self, *args):
     assert model_checkpoint.current_score == expected
 
 
+def test_best_model_metrics(tmp_path):
+    """Ensure ModelCheckpoint correctly tracks best_model_metrics."""
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            loss = super().training_step(batch, batch_idx)
+            self.log("train_loss", loss["loss"])
+            self.log("train_metric", (self.current_epoch + 1) / 10)
+            return loss
+
+        def validation_step(self, batch, batch_idx):
+            output = super().validation_step(batch, batch_idx)
+            loss = output["x"]
+            self.log("val_loss", loss)
+            self.log("val_metric", (self.current_epoch + 1) / 10)
+            return loss
+
+    checkpoint = ModelCheckpoint(dirpath=tmp_path, save_top_k=3, monitor="val_metric", mode="min")
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=3,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        callbacks=[checkpoint],
+        logger=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+
+    trainer.fit(TestModel())
+
+    assert hasattr(checkpoint, "best_model_metrics")
+    assert isinstance(checkpoint.best_model_metrics, dict)
+    assert "val_metric" in checkpoint.best_model_metrics
+    assert checkpoint.best_model_metrics["val_metric"] == 0.1  # best (lowest) value
+    assert "val_loss" in checkpoint.best_model_metrics
+    assert "train_loss" in checkpoint.best_model_metrics
+    assert "train_metric" in checkpoint.best_model_metrics
+
+
+@pytest.mark.parametrize("mode", ["min", "max"])
+def test_best_model_metrics_mode(tmp_path, mode: str):
+    """Ensure ModelCheckpoint.best_model_metrics respects the 'mode' parameter."""
+
+    class TestModel(BoringModel):
+        def validation_step(self, batch, batch_idx):
+            output = super().validation_step(batch, batch_idx)
+            metric_value = (self.current_epoch + 1) / 10
+            self.log("val_metric", metric_value)
+            return output["x"]
+
+    checkpoint = ModelCheckpoint(dirpath=tmp_path, save_top_k=1, monitor="val_metric", mode=mode)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=3,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        callbacks=[checkpoint],
+        logger=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+
+    trainer.fit(TestModel())
+
+    assert checkpoint.best_model_metrics is not None
+    assert "val_metric" in checkpoint.best_model_metrics
+
+    expected_value = 0.1 if mode == "min" else 0.3
+    assert checkpoint.best_model_metrics["val_metric"] == expected_value
+
+
 @pytest.mark.parametrize("use_omegaconf", [False, pytest.param(True, marks=RunIf(omegaconf=True))])
 def test_hparams_type(tmp_path, use_omegaconf):
     class TestModel(BoringModel):
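
For reviewers: a minimal usage sketch of the new attribute, built on the same `BoringModel` demo class the tests use. Apart from `best_model_metrics` itself, everything shown here is existing Lightning API; the `LitModel` subclass and the logged `val_loss` key are illustrative choices, not part of this patch.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel


class LitModel(BoringModel):
    def validation_step(self, batch, batch_idx):
        output = super().validation_step(batch, batch_idx)
        # Log the value the checkpoint callback monitors.
        self.log("val_loss", output["x"])
        return output


checkpoint = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
trainer = Trainer(max_epochs=2, limit_train_batches=2, limit_val_batches=2, callbacks=[checkpoint])
trainer.fit(LitModel())

# Before this patch only the monitored value survived (best_model_score);
# now the full metrics dict captured when the best checkpoint was written
# is available as well, e.g. {"val_loss": ..., "epoch": ..., "step": ...}.
print(checkpoint.best_model_score)
print(checkpoint.best_model_metrics)
```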