Skip to content

Commit fa4b2f1

Browse files
authored
Merge pull request #25 from umbertov/nonstrict_state_dict
restore/finetune: allow non-strict loading of state dict
2 parents a96c150 + 28d1878 commit fa4b2f1

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/nn_core/callbacks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class NNTemplateCore(Callback):
1717
def __init__(self, restore_cfg: Optional[DictConfig]):
1818
self.resume_ckpt_path, self.resume_run_version = parse_restore(restore_cfg)
1919
self.restore_mode: Optional[str] = restore_cfg.get("mode", None) if restore_cfg is not None else None
20+
self.restore_strict: bool = restore_cfg.get("strict", True) if restore_cfg is not None else True
2021

2122
@property
2223
def resume_id(self) -> Optional[str]:
@@ -41,7 +42,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
4142
if self.restore_mode == "finetune":
4243
checkpoint = NNCheckpointIO.load(path=Path(self.resume_ckpt_path))
4344

44-
pl_module.load_state_dict(checkpoint["state_dict"])
45+
pl_module.load_state_dict(checkpoint["state_dict"], strict=self.restore_strict)
4546

4647
def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
4748
if self._is_nnlogger(trainer):

0 commit comments

Comments
 (0)