Skip to content

Commit 5197942

Browse files
committed
fixing
1 parent d2bf065 commit 5197942

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -238,12 +238,10 @@ def __init__(
238238
self.auto_insert_metric_name = auto_insert_metric_name
239239
self._save_on_train_epoch_end = save_on_train_epoch_end
240240
self._enable_version_counter = enable_version_counter
241-
self._dirpath = dirpath
242-
self._filename = filename
243-
self._mode = mode
244-
241+
self.dirpath: Optional[_PATH] = dirpath
242+
self.filename = filename
245243
self.kth_value: Optional[Tensor] = None
246-
self.dirpath: Optional[_PATH]
244+
self._mode = mode
247245

248246
self.__init_state()
249247
self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)
@@ -275,7 +273,7 @@ def state_key(self) -> str:
275273
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
276274
self.__init_state()
277275
self.__set_monitor_mode(self._mode)
278-
self.__set_ckpt_dir(self._dirpath, self._filename)
276+
self.__set_ckpt_dir(self._dirpath, self.filename)
279277

280278
dirpath = self.__resolve_ckpt_dir(trainer)
281279
dirpath = trainer.strategy.broadcast(dirpath)
@@ -350,7 +348,7 @@ def state_dict(self) -> dict[str, Any]:
350348
"best_model_score": self.best_model_score,
351349
"best_model_path": self.best_model_path,
352350
"current_score": self.current_score,
353-
"dirpath": self._dirpath,
351+
"dirpath": self.dirpath,
354352
"best_k_models": self.best_k_models,
355353
"kth_best_model_path": self.kth_best_model_path,
356354
"kth_value": self.kth_value,
@@ -627,9 +625,9 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH:
627625
The path gets extended with subdirectory "checkpoints".
628626
629627
"""
630-
if self._dirpath is not None:
628+
if self.dirpath is not None:
631629
# short circuit if dirpath was passed to ModelCheckpoint
632-
return self._dirpath
630+
return self.dirpath
633631

634632
if len(trainer.loggers) > 0:
635633
if trainer.loggers[0].save_dir is not None:

0 commit comments

Comments
 (0)