Skip to content
35 changes: 22 additions & 13 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,18 @@ def __init__(
self.auto_insert_metric_name = auto_insert_metric_name
self._save_on_train_epoch_end = save_on_train_epoch_end
self._enable_version_counter = enable_version_counter
self._dirpath = dirpath
self._filename = filename
self._mode = mode

self.kth_value: Optional[Tensor] = None
self.dirpath: Optional[_PATH]

self.__init_state()
self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)
self.__validate_init_configuration()

def __init_state(self) -> None:
self._last_global_step_saved = 0 # no need to save when no steps were taken
self._last_time_checked: Optional[float] = None
self.current_score: Optional[Tensor] = None
Expand All @@ -248,26 +260,23 @@ def __init__(
self.last_model_path = ""
self._last_checkpoint_saved = ""

self.kth_value: Tensor
self.dirpath: Optional[_PATH]
self.__init_monitor_mode(mode)
self.__init_ckpt_dir(dirpath, filename)
self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)
self.__validate_init_configuration()

@property
@override
def state_key(self) -> str:
    """Identifier used to store/restore this callback's state in a checkpoint.

    Built from the constructor-time configuration (monitor, mode, and the
    save-frequency triggers) so that two differently-configured
    ``ModelCheckpoint`` instances get distinct state entries. Reads
    ``self._mode`` (the raw ctor argument) rather than any attribute derived
    later during ``setup``.
    """
    config = {
        "monitor": self.monitor,
        "mode": self._mode,
        "every_n_train_steps": self._every_n_train_steps,
        "every_n_epochs": self._every_n_epochs,
        "train_time_interval": self._train_time_interval,
    }
    return self._generate_state_key(**config)

@override
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
self.__init_state()
self.__set_monitor_mode(self._mode)
self.__set_ckpt_dir(self._dirpath, self._filename)

dirpath = self.__resolve_ckpt_dir(trainer)
dirpath = trainer.strategy.broadcast(dirpath)
self.dirpath = dirpath
Expand Down Expand Up @@ -341,7 +350,7 @@ def state_dict(self) -> dict[str, Any]:
"best_model_score": self.best_model_score,
"best_model_path": self.best_model_path,
"current_score": self.current_score,
"dirpath": self.dirpath,
"dirpath": self._dirpath,
"best_k_models": self.best_k_models,
"kth_best_model_path": self.kth_best_model_path,
"kth_value": self.kth_value,
Expand Down Expand Up @@ -469,7 +478,7 @@ def __validate_init_configuration(self) -> None:
" configuration. No quantity for top_k to track."
)

def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None:
def __set_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None:
self._fs = get_filesystem(dirpath if dirpath else "")

if dirpath and _is_local_file_protocol(dirpath if dirpath else ""):
Expand All @@ -478,7 +487,7 @@ def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) ->
self.dirpath = dirpath
self.filename = filename

def __init_monitor_mode(self, mode: str) -> None:
def __set_monitor_mode(self, mode: str) -> None:
torch_inf = torch.tensor(torch.inf)
mode_dict = {"min": (torch_inf, "min"), "max": (-torch_inf, "max")}

Expand Down Expand Up @@ -618,9 +627,9 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH:
The path gets extended with subdirectory "checkpoints".

"""
if self.dirpath is not None:
if self._dirpath is not None:
# short circuit if dirpath was passed to ModelCheckpoint
return self.dirpath
return self._dirpath

if len(trainer.loggers) > 0:
if trainer.loggers[0].save_dir is not None:
Expand Down
Loading