From 3b01729ec31a1771bec7c7780a1932e152141a88 Mon Sep 17 00:00:00 2001 From: Ariel Bereslavsky Date: Thu, 15 Aug 2024 09:21:12 +0300 Subject: [PATCH 1/6] fix: init state on setup --- .../pytorch/callbacks/model_checkpoint.py | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 9587da0f4600b..dc892745565e2 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -238,6 +238,18 @@ def __init__( self.auto_insert_metric_name = auto_insert_metric_name self._save_on_train_epoch_end = save_on_train_epoch_end self._enable_version_counter = enable_version_counter + self._dirpath = dirpath + self._filename = filename + self._mode = mode + + self.kth_value: Tensor + self.dirpath: Optional[_PATH] + + self.__init_state() + self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval) + self.__validate_init_configuration() + + def __init_state(self): self._last_global_step_saved = 0 # no need to save when no steps were taken self._last_time_checked: Optional[float] = None self.current_score: Optional[Tensor] = None @@ -248,12 +260,6 @@ def __init__( self.last_model_path = "" self._last_checkpoint_saved = "" - self.kth_value: Tensor - self.dirpath: Optional[_PATH] - self.__init_monitor_mode(mode) - self.__init_ckpt_dir(dirpath, filename) - self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval) - self.__validate_init_configuration() @property @override @@ -268,6 +274,10 @@ def state_key(self) -> str: @override def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + self.__init_state() + self.__set_monitor_mode(self._mode) + self.__set_ckpt_dir(self._dirpath, self._filename) + dirpath = self.__resolve_ckpt_dir(trainer) dirpath = trainer.strategy.broadcast(dirpath) self.dirpath = dirpath @@ -469,7 +479,7 @@ def __validate_init_configuration(self) -> None: " configuration. No quantity for top_k to track." ) - def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None: + def __set_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None: self._fs = get_filesystem(dirpath if dirpath else "") if dirpath and _is_local_file_protocol(dirpath if dirpath else ""): @@ -478,7 +488,7 @@ def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> self.dirpath = dirpath self.filename = filename - def __init_monitor_mode(self, mode: str) -> None: + def __set_monitor_mode(self, mode: str) -> None: torch_inf = torch.tensor(torch.inf) mode_dict = {"min": (torch_inf, "min"), "max": (-torch_inf, "max")} @@ -618,9 +628,9 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH: The path gets extended with subdirectory "checkpoints". """ - if self.dirpath is not None: + if self._dirpath is not None: # short circuit if dirpath was passed to ModelCheckpoint - return self.dirpath + return self._dirpath if len(trainer.loggers) > 0: if trainer.loggers[0].save_dir is not None: From aa29a1ff86802b166e014b7f917b0495d111fe30 Mon Sep 17 00:00:00 2001 From: Ariel Bereslavsky Date: Thu, 15 Aug 2024 09:36:29 +0300 Subject: [PATCH 2/6] fix: change state dict --- src/lightning/pytorch/callbacks/model_checkpoint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index dc892745565e2..0f17a38e0afe6 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -242,7 +242,7 @@ def __init__( self._filename = filename self._mode = mode - self.kth_value: Tensor + self.kth_value: Optional[Tensor] = None self.dirpath: Optional[_PATH] self.__init_state() @@ -266,7 +266,7 @@ def __init_state(self): def state_key(self) -> str: return self._generate_state_key( monitor=self.monitor, - mode=self.mode, + mode=self._mode, every_n_train_steps=self._every_n_train_steps, every_n_epochs=self._every_n_epochs, train_time_interval=self._train_time_interval, @@ -351,7 +351,7 @@ def state_dict(self) -> Dict[str, Any]: "best_model_score": self.best_model_score, "best_model_path": self.best_model_path, "current_score": self.current_score, - "dirpath": self.dirpath, + "dirpath": self._dirpath, "best_k_models": self.best_k_models, "kth_best_model_path": self.kth_best_model_path, "kth_value": self.kth_value, From ef6b47a8acb46f1720f0b25a2fe5392d0ab3fb3c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Aug 2024 06:39:45 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/callbacks/model_checkpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 0f17a38e0afe6..46ef371e2f13d 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -260,7 +260,6 @@ def __init_state(self): self.last_model_path = "" self._last_checkpoint_saved = "" - @property @override def state_key(self) -> str: From eb22f2c1377aa247f5cbbac98912ee819dabdacf Mon Sep 17 00:00:00 2001 From: Ariel Bereslavsky Date: Thu, 15 Aug 2024 10:06:12 +0300 Subject: [PATCH 4/6] fix: add type annotation --- src/lightning/pytorch/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 0f17a38e0afe6..ffcfea56e5f46 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -249,7 +249,7 @@ def __init__( self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval) self.__validate_init_configuration() - def __init_state(self): + def __init_state(self) -> None: self._last_global_step_saved = 0 # no need to save when no steps were taken self._last_time_checked: Optional[float] = None self.current_score: Optional[Tensor] = None From 5197942fd903b07ff0e79bade3cab385f86d8240 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Fri, 14 Mar 2025 10:59:54 +0100 Subject: [PATCH 5/6] fixing --- .../pytorch/callbacks/model_checkpoint.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 555e4984551f6..7fd5f740eee85 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -238,12 +238,10 @@ def __init__( self.auto_insert_metric_name = auto_insert_metric_name self._save_on_train_epoch_end = save_on_train_epoch_end self._enable_version_counter = enable_version_counter - self._dirpath = dirpath - self._filename = filename - self._mode = mode - + self.dirpath: Optional[_PATH] = dirpath + self.filename = filename self.kth_value: Optional[Tensor] = None - self.dirpath: Optional[_PATH] + self._mode = mode self.__init_state() self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval) @@ -275,7 +273,7 @@ def state_key(self) -> str: def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: self.__init_state() self.__set_monitor_mode(self._mode) - self.__set_ckpt_dir(self._dirpath, self._filename) + self.__set_ckpt_dir(self._dirpath, self.filename) dirpath = self.__resolve_ckpt_dir(trainer) dirpath = trainer.strategy.broadcast(dirpath) @@ -350,7 +348,7 @@ def state_dict(self) -> dict[str, Any]: "best_model_score": self.best_model_score, "best_model_path": self.best_model_path, "current_score": self.current_score, - "dirpath": self._dirpath, + "dirpath": self.dirpath, "best_k_models": self.best_k_models, "kth_best_model_path": self.kth_best_model_path, "kth_value": self.kth_value, @@ -627,9 +625,9 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH: The path gets extended with subdirectory "checkpoints". """ - if self._dirpath is not None: + if self.dirpath is not None: # short circuit if dirpath was passed to ModelCheckpoint - return self._dirpath + return self.dirpath if len(trainer.loggers) > 0: if trainer.loggers[0].save_dir is not None: From 6391a411a782f5a8faec5f2a9d20212477e2fb1f Mon Sep 17 00:00:00 2001 From: Jirka B Date: Fri, 14 Mar 2025 13:29:13 +0100 Subject: [PATCH 6/6] self._dirpath --- src/lightning/pytorch/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 7fd5f740eee85..ba59f693e67bb 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -273,7 +273,7 @@ def state_key(self) -> str: def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: self.__init_state() self.__set_monitor_mode(self._mode) - self.__set_ckpt_dir(self._dirpath, self.filename) + self.__set_ckpt_dir(self.dirpath, self.filename) dirpath = self.__resolve_ckpt_dir(trainer) dirpath = trainer.strategy.broadcast(dirpath)