From 8df37c231057c0f3f73bcae36050ee4f763fd8c4 Mon Sep 17 00:00:00 2001
From: gonzachiar
Date: Wed, 27 Nov 2024 19:38:12 -0300
Subject: [PATCH 1/5] add best_k_metrics parameter

---
 .../pytorch/callbacks/model_checkpoint.py   | 44 +++++++++++++------
 .../checkpointing/test_model_checkpoint.py  | 24 ++++++----
 2 files changed, 46 insertions(+), 22 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 9587da0f4600b..a00cd30ab4e17 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -27,7 +27,7 @@
 from copy import deepcopy
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Dict, Literal, Optional, Set, Union
+from typing import Any, Dict, Literal, Optional, Set, Union, cast
 from weakref import proxy
 
 import torch
@@ -35,13 +35,21 @@
 from torch import Tensor
 from typing_extensions import override
 
-import lightning.pytorch as pl
-from lightning.fabric.utilities.cloud_io import _is_dir, _is_local_file_protocol, get_filesystem
-from lightning.fabric.utilities.types import _PATH
-from lightning.pytorch.callbacks import Checkpoint
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_info, rank_zero_warn
-from lightning.pytorch.utilities.types import STEP_OUTPUT
+import pytorch_lightning as pl
+from lightning_fabric.utilities.cloud_io import (
+    _is_dir,
+    _is_local_file_protocol,
+    get_filesystem,
+)
+from lightning_fabric.utilities.types import _PATH
+from pytorch_lightning.callbacks import Checkpoint
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.rank_zero import (
+    WarningCache,
+    rank_zero_info,
+    rank_zero_warn,
+)
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 log = logging.getLogger(__name__)
 warning_cache = WarningCache()
@@ -241,9 +249,10 @@ def __init__(
         self._last_global_step_saved = 0  # no need to save when no steps were taken
         self._last_time_checked: Optional[float] = None
         self.current_score: Optional[Tensor] = None
-        self.best_k_models: Dict[str, Tensor] = {}
+        self.best_k_models: Dict[str, Dict[str, Tensor | Dict[str, Tensor]]] = {}
         self.kth_best_model_path = ""
         self.best_model_score: Optional[Tensor] = None
+        self.best_model_metrics: Optional[Dict[str, Tensor]] = None
         self.best_model_path = ""
         self.last_model_path = ""
         self._last_checkpoint_saved = ""
@@ -339,6 +348,7 @@ def state_dict(self) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,
             "best_model_score": self.best_model_score,
+            "best_model_metrics": self.best_model_metrics,
             "best_model_path": self.best_model_path,
             "current_score": self.current_score,
             "dirpath": self.dirpath,
@@ -354,6 +364,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
 
         if self.dirpath == dirpath_from_ckpt:
             self.best_model_score = state_dict["best_model_score"]
+            self.best_model_metrics = state_dict["best_model_metrics"]
             self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path)
             self.kth_value = state_dict.get("kth_value", self.kth_value)
             self.best_k_models = state_dict.get("best_k_models", self.best_k_models)
@@ -523,7 +534,9 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =
             return True
 
         monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
-        should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
+        should_update_best_and_save = monitor_op(
+            current, cast(Tensor, self.best_k_models[self.kth_best_model_path]["score"])
+        )
 
         # If using multiple devices, make sure all processes are unanimous on the decision.
         should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))
@@ -735,17 +748,22 @@ def _update_best_and_save(
 
         # save the current score
         self.current_score = current
-        self.best_k_models[filepath] = current
+        self.best_k_models[filepath] = {
+            "score": current,
+            "metrics": monitor_candidates,
+        }
 
         if len(self.best_k_models) == k:
             # monitor dict has reached k elements
             _op = max if self.mode == "min" else min
             self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
             self.kth_value = self.best_k_models[self.kth_best_model_path]
+            self.kth_model_metrics = self.best_k_models[self.kth_best_model_path]["metrics"]
 
         _op = min if self.mode == "min" else max
         self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
-        self.best_model_score = self.best_k_models[self.best_model_path]
+        self.best_model_score = self.best_k_models[self.best_model_path]["score"]
+        self.best_model_metrics = self.best_k_models[self.best_model_path]["metrics"]
 
         if self.verbose:
             epoch = monitor_candidates["epoch"]
@@ -762,7 +780,7 @@
     def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
         """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
         file."""
-        best_k = {k: v.item() for k, v in self.best_k_models.items()}
+        best_k = {k: v["score"].item() for k, v in self.best_k_models.items()}  # type: ignore[arg-type]
         if filepath is None:
             assert self.dirpath
             filepath = os.path.join(self.dirpath, "best_k_models.yaml")
diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index 97d8d3c4d0e4a..ab7da81901b23 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -26,11 +26,14 @@
 from unittest.mock import Mock, call, patch
 
 import cloudpickle
-import lightning.pytorch as pl
 import pytest
 import torch
 import yaml
 from jsonargparse import ArgumentParser
+from tests_pytorch.helpers.runif import RunIf
+from torch import optim
+
+import lightning.pytorch as pl
 from lightning.fabric.utilities.cloud_io import _load as pl_load
 from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
 from lightning.pytorch import Trainer, seed_everything
@@ -39,9 +42,6 @@
 from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
-from torch import optim
-
-from tests_pytorch.helpers.runif import RunIf
 
 if _OMEGACONF_AVAILABLE:
     from omegaconf import Container, OmegaConf
@@ -706,6 +706,7 @@ def test_model_checkpoint_save_last_none_monitor(tmp_path, caplog):
     assert checkpoint_callback.best_model_path == str(tmp_path / "epoch=1-step=20.ckpt")
     assert checkpoint_callback.last_model_path == str(tmp_path / "last.ckpt")
     assert checkpoint_callback.best_model_score is None
+    assert checkpoint_callback.best_model_metrics is None
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ""
 
@@ -812,6 +813,7 @@ def test_model_checkpoint_topk_zero(tmp_path):
     assert checkpoint_callback.monitor is None
     assert checkpoint_callback.best_model_path == ""
     assert checkpoint_callback.best_model_score is None
+    assert checkpoint_callback.best_model_metrics is None
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ""
     # check that only the last ckpt was created
@@ -891,11 +893,13 @@ def test_default_checkpoint_behavior(tmp_path):
     assert len(results) == 1
     save_dir = tmp_path / "checkpoints"
     save_weights_only = trainer.checkpoint_callback.save_weights_only
-    save_mock.assert_has_calls([
-        call(str(save_dir / "epoch=0-step=5.ckpt"), save_weights_only),
-        call(str(save_dir / "epoch=1-step=10.ckpt"), save_weights_only),
-        call(str(save_dir / "epoch=2-step=15.ckpt"), save_weights_only),
-    ])
+    save_mock.assert_has_calls(
+        [
+            call(str(save_dir / "epoch=0-step=5.ckpt"), save_weights_only),
+            call(str(save_dir / "epoch=1-step=10.ckpt"), save_weights_only),
+            call(str(save_dir / "epoch=2-step=15.ckpt"), save_weights_only),
+        ]
+    )
     ckpts = os.listdir(save_dir)
     assert len(ckpts) == 1
     assert ckpts[0] == "epoch=2-step=15.ckpt"
@@ -1479,6 +1483,8 @@ def test_save_last_versioning(tmp_path):
     assert all(not os.path.islink(tmp_path / path) for path in set(os.listdir(tmp_path)))
 
 
+
+
 def test_none_monitor_saves_correct_best_model_path(tmp_path):
     mc = ModelCheckpoint(dirpath=tmp_path, monitor=None)
     trainer = Trainer(callbacks=mc)

From 0b3322d59c65fd4c51d505b5a99b5bfffcde5fcd Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 10 Dec 2024 22:42:09 +0000
Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../pytorch/callbacks/model_checkpoint.py   |  7 +++----
 .../checkpointing/test_model_checkpoint.py  | 20 ++++++++++------------
 2 files changed, 11 insertions(+), 16 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index d35b4462dffe1..2d5d7df3e7c74 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -30,12 +30,9 @@
 from typing import Any, Literal, Optional, Union, cast
 from weakref import proxy
 
+import pytorch_lightning as pl
 import torch
 import yaml
-from torch import Tensor
-from typing_extensions import override
-
-import pytorch_lightning as pl
 from lightning_fabric.utilities.cloud_io import (
     _is_dir,
     _is_local_file_protocol,
@@ -50,6 +47,8 @@
     rank_zero_warn,
 )
 from pytorch_lightning.utilities.types import STEP_OUTPUT
+from torch import Tensor
+from typing_extensions import override
 
 log = logging.getLogger(__name__)
 warning_cache = WarningCache()
diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index c4eb8370977ea..031fd32b8bbb2 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -28,9 +28,6 @@
 import pytest
 import torch
 import yaml
-from tests_pytorch.helpers.runif import RunIf
-from torch import optim
-
 from lightning.fabric.utilities.cloud_io import _load as pl_load
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.callbacks import ModelCheckpoint
@@ -39,6 +36,9 @@
 from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
+from torch import optim
+
+from tests_pytorch.helpers.runif import RunIf
 
 if _OMEGACONF_AVAILABLE:
     from omegaconf import Container, OmegaConf
@@ -888,13 +888,11 @@ def test_default_checkpoint_behavior(tmp_path):
     assert len(results) == 1
     save_dir = tmp_path / "checkpoints"
     save_weights_only = trainer.checkpoint_callback.save_weights_only
-    save_mock.assert_has_calls(
-        [
-            call(str(save_dir / "epoch=0-step=5.ckpt"), save_weights_only),
-            call(str(save_dir / "epoch=1-step=10.ckpt"), save_weights_only),
-            call(str(save_dir / "epoch=2-step=15.ckpt"), save_weights_only),
-        ]
-    )
+    save_mock.assert_has_calls([
+        call(str(save_dir / "epoch=0-step=5.ckpt"), save_weights_only),
+        call(str(save_dir / "epoch=1-step=10.ckpt"), save_weights_only),
+        call(str(save_dir / "epoch=2-step=15.ckpt"), save_weights_only),
+    ])
     ckpts = os.listdir(save_dir)
     assert len(ckpts) == 1
     assert ckpts[0] == "epoch=2-step=15.ckpt"
@@ -1478,8 +1476,6 @@ def test_save_last_versioning(tmp_path):
     assert all(not os.path.islink(tmp_path / path) for path in set(os.listdir(tmp_path)))
 
 
-
-
 def test_none_monitor_saves_correct_best_model_path(tmp_path):
     mc = ModelCheckpoint(dirpath=tmp_path, monitor=None)
     trainer = Trainer(callbacks=mc)

From a2141548ed3d56c7364f9e26e1f39075f3e6001f Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 16 Apr 2025 07:20:36 +0000
Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/pytorch/callbacks/model_checkpoint.py        | 7 ++++---
 tests/tests_pytorch/checkpointing/test_model_checkpoint.py | 1 -
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 2d5d7df3e7c74..d35b4462dffe1 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -30,9 +30,12 @@
 from typing import Any, Literal, Optional, Union, cast
 from weakref import proxy
 
-import pytorch_lightning as pl
 import torch
 import yaml
+from torch import Tensor
+from typing_extensions import override
+
+import pytorch_lightning as pl
 from lightning_fabric.utilities.cloud_io import (
     _is_dir,
     _is_local_file_protocol,
@@ -47,8 +50,6 @@
     rank_zero_warn,
 )
 from pytorch_lightning.utilities.types import STEP_OUTPUT
-from torch import Tensor
-from typing_extensions import override
 
 log = logging.getLogger(__name__)
 warning_cache = WarningCache()
diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index 7a581280536ae..c0e4c3f2c4f2b 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -32,7 +32,6 @@
 from torch import optim
 from torch.utils.data.dataloader import DataLoader
 
-import lightning.pytorch as pl
 from lightning.fabric.utilities.cloud_io import _load as pl_load
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.callbacks import ModelCheckpoint

From 80c4bb65308384ffbb2e71628978449fe6363316 Mon Sep 17 00:00:00 2001
From: gonzachiar
Date: Wed, 16 Apr 2025 10:19:56 -0300
Subject: [PATCH 4/5] fix: modify and sort imports

---
 .../pytorch/callbacks/model_checkpoint.py | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index d35b4462dffe1..e5cbf5e40971e 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -35,21 +35,21 @@
 from torch import Tensor
 from typing_extensions import override
 
-import pytorch_lightning as pl
+import lightning.pytorch as pl
+from lightning.pytorch.callbacks import Checkpoint
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.rank_zero import (
+    WarningCache,
+    rank_zero_info,
+    rank_zero_warn,
+)
+from lightning.pytorch.utilities.types import STEP_OUTPUT
 from lightning_fabric.utilities.cloud_io import (
     _is_dir,
     _is_local_file_protocol,
     get_filesystem,
 )
 from lightning_fabric.utilities.types import _PATH
-from pytorch_lightning.callbacks import Checkpoint
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.rank_zero import (
-    WarningCache,
-    rank_zero_info,
-    rank_zero_warn,
-)
-from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 log = logging.getLogger(__name__)
 warning_cache = WarningCache()
@@ -252,7 +252,7 @@ def __init__(
         self.best_k_models: dict[str, dict[str, Tensor | dict[str, Tensor]]] = {}
         self.kth_best_model_path = ""
         self.best_model_score: Optional[Tensor] = None
-        self.best_model_metrics: Optional[Dict[str, Tensor]] = None
+        self.best_model_metrics: Optional[dict[str, Tensor]] = None
         self.best_model_path = ""
         self.last_model_path = ""
         self._last_checkpoint_saved = ""

From 821bfef8219bf240a38eaa58d22e5d7069e4e026 Mon Sep 17 00:00:00 2001
From: gonzachiar
Date: Wed, 16 Apr 2025 10:43:20 -0300
Subject: [PATCH 5/5] fix: revert changes on best_k_models

---
 .../pytorch/callbacks/model_checkpoint.py | 25 ++++++++-----------
 1 file changed, 10 insertions(+), 15 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index e5cbf5e40971e..f771304a525c7 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -27,7 +27,7 @@
 from copy import deepcopy
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Literal, Optional, Union, cast
+from typing import Any, Literal, Optional, Union
 from weakref import proxy
 
 import torch
@@ -249,7 +249,7 @@ def __init__(
         self._last_global_step_saved = 0  # no need to save when no steps were taken
         self._last_time_checked: Optional[float] = None
         self.current_score: Optional[Tensor] = None
-        self.best_k_models: dict[str, dict[str, Tensor | dict[str, Tensor]]] = {}
+        self.best_k_models: dict[str, Tensor] = {}
         self.kth_best_model_path = ""
         self.best_model_score: Optional[Tensor] = None
         self.best_model_metrics: Optional[dict[str, Tensor]] = None
@@ -372,8 +372,8 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         else:
             warnings.warn(
                 f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r},"
-                " therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and"
-                " `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded."
+                " therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path`,"
+                " `best_k_models` and `best_model_metrics` won't be reloaded. Only `best_model_path` will be reloaded."
             )
 
         self.best_model_path = state_dict["best_model_path"]
@@ -534,9 +534,7 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =
             return True
 
         monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
-        should_update_best_and_save = monitor_op(
-            current, cast(Tensor, self.best_k_models[self.kth_best_model_path]["score"])
-        )
+        should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
 
         # If using multiple devices, make sure all processes are unanimous on the decision.
         should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))
@@ -748,22 +746,19 @@ def _update_best_and_save(
 
         # save the current score
         self.current_score = current
-        self.best_k_models[filepath] = {
-            "score": current,
-            "metrics": monitor_candidates,
-        }
+        self.best_k_models[filepath] = current
 
         if len(self.best_k_models) == k:
             # monitor dict has reached k elements
             _op = max if self.mode == "min" else min
             self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
             self.kth_value = self.best_k_models[self.kth_best_model_path]
-            self.kth_model_metrics = self.best_k_models[self.kth_best_model_path]["metrics"]
 
         _op = min if self.mode == "min" else max
         self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
-        self.best_model_score = self.best_k_models[self.best_model_path]["score"]
-        self.best_model_metrics = self.best_k_models[self.best_model_path]["metrics"]
+        self.best_model_score = self.best_k_models[self.best_model_path]
+        if self.best_model_path == filepath:
+            self.best_model_metrics = monitor_candidates
 
         if self.verbose:
             epoch = monitor_candidates["epoch"]
@@ -780,7 +775,7 @@
     def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
         """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
         file."""
-        best_k = {k: v["score"].item() for k, v in self.best_k_models.items()}  # type: ignore[arg-type]
+        best_k = {k: v.item() for k, v in self.best_k_models.items()}
         if filepath is None:
             assert self.dirpath
             filepath = os.path.join(self.dirpath, "best_k_models.yaml")
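
Usage note, not part of the patches above: a minimal sketch of the behavior the series converges on after PATCH 5/5, assuming it is applied. The new `best_model_metrics` attribute snapshots the full `monitor_candidates` dict logged at the moment the best checkpoint is written, alongside the existing `best_model_path` and `best_model_score`. `MetricsModel` is a hypothetical toy module built on Lightning's `BoringModel` demo class, and the logged `val_acc` is a random value, included purely to show a second metric being captured.

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel


class MetricsModel(BoringModel):
    def validation_step(self, batch, batch_idx):
        out = super().validation_step(batch, batch_idx)
        # Log the monitored metric plus a secondary one; both land in the
        # monitor_candidates dict that `best_model_metrics` snapshots.
        self.log("val_loss", out["x"])
        self.log("val_acc", torch.rand(()))
        return out


if __name__ == "__main__":
    mc = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=2)
    trainer = Trainer(max_epochs=3, callbacks=[mc], limit_train_batches=2, limit_val_batches=2)
    trainer.fit(MetricsModel())

    print(mc.best_model_path)     # path of the best checkpoint
    print(mc.best_model_score)    # monitored value at that checkpoint
    print(mc.best_model_metrics)  # all metrics logged at that checkpoint (added by this series)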