Skip to content

Commit 28d1878

Browse files
umbertovlucmos
authored andcommitted
fix mistake and improve typing
just noticed i made a logic mistake at line 45, this commit is fixing that
1 parent e9a2554 commit 28d1878

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/nn_core/callbacks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class NNTemplateCore(Callback):
1717
def __init__(self, restore_cfg: Optional[DictConfig]):
1818
self.resume_ckpt_path, self.resume_run_version = parse_restore(restore_cfg)
1919
self.restore_mode: Optional[str] = restore_cfg.get("mode", None) if restore_cfg is not None else None
20-
self.restore_strict: Optional[bool] = restore_cfg.get("strict", None) if restore_cfg is not None else None
20+
self.restore_strict: bool = restore_cfg.get("strict", True) if restore_cfg is not None else True
2121

2222
@property
2323
def resume_id(self) -> Optional[str]:
@@ -42,7 +42,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
4242
if self.restore_mode == "finetune":
4343
checkpoint = NNCheckpointIO.load(path=Path(self.resume_ckpt_path))
4444

45-
pl_module.load_state_dict(checkpoint["state_dict"], strict=self.restore_strict or True)
45+
pl_module.load_state_dict(checkpoint["state_dict"], strict=self.restore_strict)
4646

4747
def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
4848
if self._is_nnlogger(trainer):

0 commit comments

Comments
 (0)