Skip to content

Commit 585e7a6

Browse files
authored
Merge pull request #30 from grok-ai/develop
Hotfix load model method
2 parents fa4b2f1 + cabf23e commit 585e7a6

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/nn_core/serialization.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import zipfile
99
from contextlib import contextmanager
1010
from pathlib import Path
11-
from typing import Any, Callable, Dict, Optional, Type, Union
11+
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
1212

1313
import pytorch_lightning as pl
1414
import torch
@@ -151,18 +151,19 @@ def _substistute(dictionary, substitute_values: Dict[str, str], substitute_keys:
151151
def load_model(
152152
module_class: Type[pl.LightningModule],
153153
checkpoint_path: Path,
154+
strict: bool = True,
154155
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
155156
substitute_keys: Optional[Dict[str, str]] = None,
156157
substitute_values: Optional[Dict[str, str]] = None,
157-
):
158+
) -> Tuple[pl.LightningModule, Dict[str, Any]]:
158159
# Lightning checkpoints end with .ckpt, ours with .ckpt.zip
159160
if checkpoint_path.name.endswith(".ckpt.zip"):
160161
checkpoint = NNCheckpointIO.load(path=checkpoint_path, map_location=map_location)
161162

162163
if substitute_values is not None:
163164
checkpoint = _substistute(checkpoint, substitute_values=substitute_values, substitute_keys=substitute_keys)
164165

165-
return _load_state(cls=module_class, checkpoint=checkpoint, metadata=checkpoint.get("metadata", None))
166+
return _load_state(cls=module_class, checkpoint=checkpoint, strict=strict, metadata=checkpoint.get("metadata", None)), checkpoint
166167
else:
167168
pylogger.warning(f"Loading a legacy checkpoint (from vanilla PyTorch Lightning): '{checkpoint_path}'")
168-
module_class.load_from_checkpoint(checkpoint_path=str(checkpoint_path), map_location=map_location)
169+
return module_class.load_from_checkpoint(checkpoint_path=str(checkpoint_path), map_location=map_location), None

0 commit comments

Comments
 (0)