4 changes: 4 additions & 0 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -256,6 +256,7 @@ def __init__(
         self._last_global_step_saved = 0  # no need to save when no steps were taken
         self._last_time_checked: Optional[float] = None
         self.current_score: Optional[Tensor] = None
+        self.best_model_metrics: Optional[dict[str, Tensor]] = {}
         self.best_k_models: dict[str, Tensor] = {}
         self.kth_best_model_path = ""
         self.best_model_score: Optional[Tensor] = None
@@ -858,6 +859,9 @@ def _update_best_and_save(
         self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
         self.best_model_score = self.best_k_models[self.best_model_path]

+        if self.best_model_path == filepath:
+            self.best_model_metrics = dict(monitor_candidates)
+
         if self.verbose:
             epoch = monitor_candidates["epoch"]
             step = monitor_candidates["step"]
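For context, a minimal usage sketch of the attribute this hunk adds. This is hypothetical: `MyModel` and the logged metric name `val_metric` are placeholders, not part of this diff.

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint

    # Assumes MyModel logs "val_metric" in its validation_step.
    checkpoint = ModelCheckpoint(monitor="val_metric", mode="min", save_top_k=3)
    trainer = Trainer(max_epochs=3, callbacks=[checkpoint])
    trainer.fit(MyModel())

    print(checkpoint.best_model_path)     # path of the best checkpoint
    print(checkpoint.best_model_score)    # monitored value at that save
    print(checkpoint.best_model_metrics)  # all logged metrics at that save

Unlike `best_model_score`, which keeps only the monitored value, `best_model_metrics` snapshots the full `monitor_candidates` dict at the moment the best checkpoint is written.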
74 changes: 74 additions & 0 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -1637,6 +1637,80 @@ def training_step(self, *args):
     assert model_checkpoint.current_score == expected


+def test_best_model_metrics(tmp_path):
+    """Ensure ModelCheckpoint correctly tracks best_model_metrics."""
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            loss = super().training_step(batch, batch_idx)
+            self.log("train_loss", loss["loss"])
+            self.log("train_metric", (self.current_epoch + 1) / 10)
+            return loss
+
+        def validation_step(self, batch, batch_idx):
+            output = super().validation_step(batch, batch_idx)
+            loss = output["x"]
+            self.log("val_loss", loss)
+            self.log("val_metric", (self.current_epoch + 1) / 10)
+            return loss
+
+    checkpoint = ModelCheckpoint(dirpath=tmp_path, save_top_k=3, monitor="val_metric", mode="min")
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=3,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        callbacks=[checkpoint],
+        logger=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+
+    trainer.fit(TestModel())
+
+    assert hasattr(checkpoint, "best_model_metrics")
+    assert isinstance(checkpoint.best_model_metrics, dict)
+    assert "val_metric" in checkpoint.best_model_metrics
+    assert checkpoint.best_model_metrics["val_metric"] == 0.1  # best (lowest) value
+    assert "val_loss" in checkpoint.best_model_metrics
+    assert "train_loss" in checkpoint.best_model_metrics
+    assert "train_metric" in checkpoint.best_model_metrics
+
+
+@pytest.mark.parametrize("mode", ["min", "max"])
+def test_best_model_metrics_mode(tmp_path, mode: str):
+    """Ensure ModelCheckpoint.best_model_metrics respects the 'mode' parameter."""
+
+    class TestModel(BoringModel):
+        def validation_step(self, batch, batch_idx):
+            output = super().validation_step(batch, batch_idx)
+            metric_value = (self.current_epoch + 1) / 10
+            self.log("val_metric", metric_value)
+            return output["x"]
+
+    checkpoint = ModelCheckpoint(dirpath=tmp_path, save_top_k=1, monitor="val_metric", mode=mode)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=3,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        callbacks=[checkpoint],
+        logger=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+
+    trainer.fit(TestModel())
+
+    assert checkpoint.best_model_metrics is not None
+    assert "val_metric" in checkpoint.best_model_metrics
+
+    expected_value = 0.1 if mode == "min" else 0.3
+    assert checkpoint.best_model_metrics["val_metric"] == expected_value
+
 @pytest.mark.parametrize("use_omegaconf", [False, pytest.param(True, marks=RunIf(omegaconf=True))])
 def test_hparams_type(tmp_path, use_omegaconf):
     class TestModel(BoringModel):
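As a closing illustration, here is a self-contained sketch (a toy version, not Lightning's actual code) of the top-k bookkeeping that `_update_best_and_save` performs and that the tests above exercise. It shows why the metrics snapshot is taken only when the checkpoint just written becomes the new best.

    import torch

    def update_best(state: dict, filepath: str, monitor_candidates: dict, monitor: str, top_k: int, mode: str):
        """Toy best/top-k tracking; `state` mirrors the callback attributes."""
        state["best_k_models"][filepath] = monitor_candidates[monitor]
        if len(state["best_k_models"]) > top_k:
            # Evict the worst of the k+1 tracked checkpoints.
            worst = max if mode == "min" else min
            del state["best_k_models"][worst(state["best_k_models"], key=state["best_k_models"].get)]
        best = min if mode == "min" else max
        state["best_model_path"] = best(state["best_k_models"], key=state["best_k_models"].get)
        if state["best_model_path"] == filepath:
            # The checkpoint just written is the new best: snapshot all metrics.
            state["best_model_metrics"] = dict(monitor_candidates)

    state = {"best_k_models": {}, "best_model_path": "", "best_model_metrics": {}}
    for epoch in range(3):
        candidates = {"val_metric": torch.tensor((epoch + 1) / 10)}
        update_best(state, f"epoch={epoch}.ckpt", candidates, "val_metric", top_k=1, mode="min")
    assert state["best_model_metrics"]["val_metric"] == 0.1  # matches the test expectation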