|
12 | 12 |
|
13 | 13 | import pytorch_lightning as pl |
14 | 14 | import torch |
| 15 | +from pytorch_lightning.core.saving import _load_state |
15 | 16 | from pytorch_lightning.plugins import TorchCheckpointIO |
16 | 17 |
|
17 | 18 | METADATA_KEY: str = "metadata" |
@@ -94,7 +95,7 @@ def remove_checkpoint(self, path) -> None: |
94 | 95 |
|
95 | 96 |
|
96 | 97 | def compress_checkpoint(src_dir: Path, dst_file: Path, delete_dir: bool = True): |
97 | | - |
| 98 | + dst_file.parent.mkdir(parents=True, exist_ok=True) |
98 | 99 | with zipfile.ZipFile(_normalize_path(dst_file), "w") as zip_file: |
99 | 100 | for folder, subfolders, files in os.walk(src_dir): |
100 | 101 | folder: Path = Path(folder) |
@@ -161,7 +162,7 @@ def load_model( |
161 | 162 | if substitute_values is not None: |
162 | 163 | checkpoint = _substistute(checkpoint, substitute_values=substitute_values, substitute_keys=substitute_keys) |
163 | 164 |
|
164 | | - return module_class._load_model_state(checkpoint=checkpoint, metadata=checkpoint.get("metadata", None)) |
| 165 | + return _load_state(cls=module_class, checkpoint=checkpoint, metadata=checkpoint.get("metadata", None)) |
165 | 166 | else: |
166 | 167 | pylogger.warning(f"Loading a legacy checkpoint (from vanilla PyTorch Lightning): '{checkpoint_path}'") |
167 | 168 | module_class.load_from_checkpoint(checkpoint_path=str(checkpoint_path), map_location=map_location) |
0 commit comments