Skip to content

Commit 2eab819

Browse files
committed
Update Lightning to 1.7.*
1 parent 6ecc71f commit 2eab819

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ packages=find:
1717
install_requires =
1818
# Add project specific dependencies
1919
# Stuff easy to break with updates
20-
pytorch-lightning>=1.5.8,<1.6
20+
pytorch-lightning==1.7.*
2121
hydra-core
2222
wandb
2323

src/nn_core/serialization.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import pytorch_lightning as pl
1414
import torch
15+
from pytorch_lightning.core.saving import _load_state
1516
from pytorch_lightning.plugins import TorchCheckpointIO
1617

1718
METADATA_KEY: str = "metadata"
@@ -94,7 +95,7 @@ def remove_checkpoint(self, path) -> None:
9495

9596

9697
def compress_checkpoint(src_dir: Path, dst_file: Path, delete_dir: bool = True):
97-
98+
dst_file.parent.mkdir(parents=True, exist_ok=True)
9899
with zipfile.ZipFile(_normalize_path(dst_file), "w") as zip_file:
99100
for folder, subfolders, files in os.walk(src_dir):
100101
folder: Path = Path(folder)
@@ -161,7 +162,7 @@ def load_model(
161162
if substitute_values is not None:
162163
checkpoint = _substistute(checkpoint, substitute_values=substitute_values, substitute_keys=substitute_keys)
163164

164-
return module_class._load_model_state(checkpoint=checkpoint, metadata=checkpoint.get("metadata", None))
165+
return _load_state(cls=module_class, checkpoint=checkpoint, metadata=checkpoint.get("metadata", None))
165166
else:
166167
pylogger.warning(f"Loading a legacy checkpoint (from vanilla PyTorch Lightning): '{checkpoint_path}'")
167168
module_class.load_from_checkpoint(checkpoint_path=str(checkpoint_path), map_location=map_location)

0 commit comments

Comments
 (0)