|
8 | 8 | import zipfile |
9 | 9 | from contextlib import contextmanager |
10 | 10 | from pathlib import Path |
11 | | -from typing import Any, Callable, Dict, Optional, Type, Union |
| 11 | +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union |
12 | 12 |
|
13 | 13 | import pytorch_lightning as pl |
14 | 14 | import torch |
@@ -151,18 +151,19 @@ def _substistute(dictionary, substitute_values: Dict[str, str], substitute_keys: |
151 | 151 | def load_model( |
152 | 152 | module_class: Type[pl.LightningModule], |
153 | 153 | checkpoint_path: Path, |
| 154 | + strict: bool = True, |
154 | 155 | map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, |
155 | 156 | substitute_keys: Optional[Dict[str, str]] = None, |
156 | 157 | substitute_values: Optional[Dict[str, str]] = None, |
157 | | -): |
| 158 | +) -> Tuple[pl.LightningModule, Dict[str, Any]]: |
158 | 159 | # Lightning checkpoints end with .ckpt, ours with .ckpt.zip |
159 | 160 | if checkpoint_path.name.endswith(".ckpt.zip"): |
160 | 161 | checkpoint = NNCheckpointIO.load(path=checkpoint_path, map_location=map_location) |
161 | 162 |
|
162 | 163 | if substitute_values is not None: |
163 | 164 | checkpoint = _substistute(checkpoint, substitute_values=substitute_values, substitute_keys=substitute_keys) |
164 | 165 |
|
165 | | - return _load_state(cls=module_class, checkpoint=checkpoint, metadata=checkpoint.get("metadata", None)) |
| 166 | + return _load_state(cls=module_class, checkpoint=checkpoint, strict=strict, metadata=checkpoint.get("metadata", None)), checkpoint |
166 | 167 | else: |
167 | 168 | pylogger.warning(f"Loading a legacy checkpoint (from vanilla PyTorch Lightning): '{checkpoint_path}'") |
168 | | - module_class.load_from_checkpoint(checkpoint_path=str(checkpoint_path), map_location=map_location) |
| 169 | + return module_class.load_from_checkpoint(checkpoint_path=str(checkpoint_path), map_location=map_location), None |
0 commit comments