Skip to content

Commit f55551b

Browse files
authored
Merge pull request #26 from grok-ai/feature/bump-dependencies
Update to Lightning 1.7
2 parents 6ecc71f + 5fe4662 commit f55551b

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ packages=find:
1717
install_requires =
1818
# Add project specific dependencies
1919
# Stuff easy to break with updates
20-
pytorch-lightning>=1.5.8,<1.6
20+
pytorch-lightning==1.7.*
2121
hydra-core
2222
wandb
2323

src/nn_core/model_logging.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@ def __init__(self, logging_cfg: DictConfig, cfg: DictConfig, resume_id: Optional
3939
self.logging_cfg.logger.mode = "offline"
4040

4141
pylogger.info(f"Instantiating <{self.logging_cfg.logger['_target_'].split('.')[-1]}>")
42-
self.wrapped: LightningLoggerBase = hydra.utils.instantiate(self.logging_cfg.logger, version=self.resume_id)
42+
self.wrapped: LightningLoggerBase = hydra.utils.instantiate(
43+
self.logging_cfg.logger,
44+
version=self.resume_id,
45+
dir=os.getenv("WANDB_DIR", "."),
46+
)
4347

4448
# force experiment lazy initialization
4549
_ = self.wrapped.experiment

src/nn_core/serialization.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import pytorch_lightning as pl
1414
import torch
15+
from pytorch_lightning.core.saving import _load_state
1516
from pytorch_lightning.plugins import TorchCheckpointIO
1617

1718
METADATA_KEY: str = "metadata"
@@ -94,7 +95,7 @@ def remove_checkpoint(self, path) -> None:
9495

9596

9697
def compress_checkpoint(src_dir: Path, dst_file: Path, delete_dir: bool = True):
97-
98+
dst_file.parent.mkdir(parents=True, exist_ok=True)
9899
with zipfile.ZipFile(_normalize_path(dst_file), "w") as zip_file:
99100
for folder, subfolders, files in os.walk(src_dir):
100101
folder: Path = Path(folder)
@@ -161,7 +162,7 @@ def load_model(
161162
if substitute_values is not None:
162163
checkpoint = _substistute(checkpoint, substitute_values=substitute_values, substitute_keys=substitute_keys)
163164

164-
return module_class._load_model_state(checkpoint=checkpoint, metadata=checkpoint.get("metadata", None))
165+
return _load_state(cls=module_class, checkpoint=checkpoint, metadata=checkpoint.get("metadata", None))
165166
else:
166167
pylogger.warning(f"Loading a legacy checkpoint (from vanilla PyTorch Lightning): '{checkpoint_path}'")
167168
module_class.load_from_checkpoint(checkpoint_path=str(checkpoint_path), map_location=map_location)

0 commit comments

Comments (0)