
Commit a56bbcc

Merge pull request #13 from Flegyas/feature/stefan
Integrate feedback from user(s)
2 parents d5df104 + cc084e0 commit a56bbcc

6 files changed: +225 additions, -28 deletions

src/nn_core/callbacks.py

Lines changed: 24 additions & 1 deletion
@@ -1,15 +1,31 @@
 import logging
-from typing import Any, Dict
+from pathlib import Path
+from typing import Any, Dict, Optional
 
 import pytorch_lightning as pl
+from omegaconf import DictConfig
 from pytorch_lightning import Callback, Trainer
 
 from nn_core.model_logging import NNLogger
+from nn_core.resume import parse_restore
+from nn_core.serialization import METADATA_KEY, NNCheckpointIO
 
 pylogger = logging.getLogger(__name__)
 
 
 class NNTemplateCore(Callback):
+    def __init__(self, restore_cfg: Optional[DictConfig]):
+        self.resume_ckpt_path, self.resume_run_version = parse_restore(restore_cfg)
+        self.restore_mode: Optional[str] = restore_cfg.get("mode", None) if restore_cfg is not None else None
+
+    @property
+    def resume_id(self) -> Optional[str]:
+        return self.resume_run_version
+
+    @property
+    def trainer_ckpt_path(self) -> Optional[str]:
+        return self.resume_ckpt_path if self.restore_mode != "finetune" else None
+
     @staticmethod
     def _is_nnlogger(trainer: Trainer) -> bool:
         return isinstance(trainer.logger, NNLogger)
@@ -21,6 +37,12 @@ def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) ->
             trainer.logger.log_configuration(model=pl_module)
             trainer.logger.watch_model(pl_module=pl_module)
 
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+        if self.restore_mode == "finetune":
+            checkpoint = NNCheckpointIO.load(path=Path(self.resume_ckpt_path))
+
+            pl_module.load_state_dict(checkpoint["state_dict"])
+
     def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
         if self._is_nnlogger(trainer):
             trainer.logger: NNLogger
@@ -31,3 +53,4 @@ def on_save_checkpoint(
     ) -> None:
         if self._is_nnlogger(trainer):
            trainer.logger.on_save_checkpoint(trainer=trainer, pl_module=pl_module, checkpoint=checkpoint)
+            checkpoint[METADATA_KEY] = trainer.datamodule.metadata
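Note: the new constructor wires checkpoint restoration through a restore_cfg: parse_restore resolves the checkpoint path and run version, and the "finetune" mode loads only the weights in setup() while trainer_ckpt_path stays None. A minimal usage sketch follows; the config values and the way the callback is handed to the Trainer are illustrative assumptions, not the template's exact wiring.

import pytorch_lightning as pl
from omegaconf import OmegaConf

from nn_core.callbacks import NNTemplateCore

# Hypothetical restore config: "finetune" restores only the model weights,
# so trainer_ckpt_path stays None and the Trainer state is not resumed.
restore_cfg = OmegaConf.create({"ckpt_or_run_path": "checkpoints/best.ckpt.zip", "mode": "finetune"})

template_core = NNTemplateCore(restore_cfg=restore_cfg)
trainer = pl.Trainer(callbacks=[template_core])
# trainer.fit(model, datamodule=datamodule, ckpt_path=template_core.trainer_ckpt_path)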

src/nn_core/common/utils.py

Lines changed: 28 additions & 2 deletions
@@ -1,5 +1,6 @@
 import logging
 import os
+from contextlib import contextmanager
 from typing import List, Optional
 
 import dotenv
@@ -53,6 +54,31 @@ def load_envs(env_file: Optional[str] = None) -> None:
     dotenv.load_dotenv(dotenv_path=env_file, override=True)
 
 
+@contextmanager
+def environ(**kwargs):
+    """Temporarily set the process environment variables.
+
+    https://stackoverflow.com/a/34333710
+
+    >>> with environ(PLUGINS_DIR=u'test/plugins'):
+    ...     "PLUGINS_DIR" in os.environ
+    True
+
+    >>> "PLUGINS_DIR" in os.environ
+    False
+
+    :type kwargs: dict[str, unicode]
+    :param kwargs: Environment variables to set
+    """
+    old_environ = dict(os.environ)
+    os.environ.update(kwargs)
+    try:
+        yield
+    finally:
+        os.environ.clear()
+        os.environ.update(old_environ)
+
+
 def enforce_tags(tags: Optional[List[str]]) -> List[str]:
     if tags is None:
         if "id" in HydraConfig().cfg.hydra.job:
@@ -69,10 +95,10 @@ def enforce_tags(tags: Optional[List[str]]) -> List[str]:
     return tags
 
 
-def seed_index_everything(train_cfg: DictConfig) -> None:
+def seed_index_everything(train_cfg: DictConfig, sampling_seed: int = 42) -> None:
     if "seed_index" in train_cfg and train_cfg.seed_index is not None:
         seed_index = train_cfg.seed_index
-        seed_everything(42)
+        np.random.seed(sampling_seed)
         seeds = np.random.randint(np.iinfo(np.int32).max, size=max(42, seed_index + 1))
         seed = seeds[seed_index]
         seed_everything(seed)
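Note: two additions here. environ() temporarily overrides process environment variables and restores the previous environment on exit, and seed_index_everything() now takes an explicit sampling_seed and seeds only NumPy's RNG before drawing the per-index seed, instead of the hard-coded seed_everything(42). A short sketch with illustrative values (the config keys mirror the diff; everything else is assumed):

import os

from omegaconf import OmegaConf

from nn_core.common.utils import environ, seed_index_everything

# Variables set via environ() exist only inside the block and are restored afterwards.
with environ(WANDB_MODE="offline"):
    assert os.environ["WANDB_MODE"] == "offline"

# The seed used for seed_index=3 is reproducible because the sampling RNG is
# seeded explicitly before the candidate seeds are drawn.
cfg = OmegaConf.create({"seed_index": 3})
seed_index_everything(cfg, sampling_seed=42)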

src/nn_core/hooks.py

Whitespace-only changes.
File renamed without changes.

src/nn_core/resume.py

Lines changed: 50 additions & 25 deletions
@@ -1,30 +1,42 @@
 import logging
 import re
+import tempfile
 from operator import xor
 from pathlib import Path
 from typing import Optional, Tuple
 
-import torch
 import wandb
 from omegaconf import DictConfig
 from wandb.apis.public import Run
 
+from nn_core.common import PROJECT_ROOT
+from nn_core.common.utils import environ
+from nn_core.serialization import NNCheckpointIO
+
 pylogger = logging.getLogger(__name__)
 
 RUN_PATH_PATTERN = re.compile(r"^([^/]+)/([^/]+)/([^/]+)$")
 
 RESUME_MODES = {
-    "continue": {
-        "restore_model": True,
-        "restore_run": True,
+    None: {
+        "logging": False,
+        "trainer": False,
+        "weights": False,
+    },
+    "finetune": {
+        "logging": False,
+        "trainer": False,
+        "weights": True,
     },
     "hotstart": {
-        "restore_model": True,
-        "restore_run": False,
+        "logging": False,
+        "trainer": True,
+        "weights": True,
     },
-    None: {
-        "restore_model": False,
-        "restore_run": False,
+    "continue": {
+        "logging": True,
+        "trainer": True,
+        "weights": True,
     },
 }
 
@@ -38,16 +50,26 @@ def resolve_ckpt(ckpt_or_run_path: str) -> str:
     Returns:
         an existing path towards the best checkpoint
     """
-    if Path(ckpt_or_run_path).exists():
-        return ckpt_or_run_path
-
-    try:
-        api = wandb.Api()
-        run: Run = api.run(path=ckpt_or_run_path)
-        ckpt_or_run_path = run.config["paths/checkpoints/best"]
+    if RUN_PATH_PATTERN.match(ckpt_or_run_path):
+        # If WANDB_DIR is set (as it is the case with our hydra configuration), the run dir is created by wandb in the
+        # project's root folder instead of in a temp directory.
+        with tempfile.TemporaryDirectory() as tmp_dir, environ(WANDB_DIR=tmp_dir):
+            # We are resolving the path from a wandb run id
+            try:
+                api = wandb.Api()
+                run: Run = api.run(path=ckpt_or_run_path)
+                ckpt_or_run_path = run.config["paths/checkpoints/best"]
+                return ckpt_or_run_path
+            except wandb.errors.CommError:
+                raise ValueError(f"Checkpoint or run not found: {ckpt_or_run_path}")
+
+    _ckpt_or_run_path: Path = Path(ckpt_or_run_path)
+    # If the path is relative, it is wrt the PROJECT_ROOT, so it is prepended.
+    if not _ckpt_or_run_path.is_absolute():
+        _ckpt_or_run_path = PROJECT_ROOT / _ckpt_or_run_path
+
+    if _ckpt_or_run_path.exists():
         return ckpt_or_run_path
-    except wandb.errors.CommError:
-        raise ValueError(f"Checkpoint or run not found: {ckpt_or_run_path}")
 
 
 def resolve_run_path(ckpt_or_run_path: str) -> str:
@@ -63,7 +85,7 @@ def resolve_run_path(ckpt_or_run_path: str) -> str:
         return ckpt_or_run_path
 
     try:
-        return torch.load(ckpt_or_run_path)["run_path"]
+        return NNCheckpointIO.load(path=Path(ckpt_or_run_path))["run_path"]
     except FileNotFoundError:
         raise ValueError(f"Checkpoint or run not found: {ckpt_or_run_path}")
 
@@ -83,7 +105,11 @@ def resolve_run_version(ckpt_or_run_path: Optional[str] = None, run_path: Option
     return RUN_PATH_PATTERN.match(run_path).group(3)
 
 
+# TODO: Refactor returning type to include restore mode too.
 def parse_restore(restore_cfg: DictConfig) -> Tuple[Optional[str], Optional[str]]:
+    if restore_cfg is None:
+        return None, None
+
     ckpt_or_run_path = restore_cfg.ckpt_or_run_path
     resume_mode = restore_cfg.mode
 
@@ -92,22 +118,21 @@ def parse_restore(restore_cfg: DictConfig) -> Tuple[Optional[str], Optional[str]
 
     if xor(bool(ckpt_or_run_path), bool(resume_mode)):
         pylogger.warning(f"Inconsistent resume modality {resume_mode} and checkpoint path '{ckpt_or_run_path}'")
+    else:
+        pylogger.info(f"Restoring with mode: <{resume_mode}>")
 
     if resume_mode not in RESUME_MODES:
         message = f"Unsupported resume mode {resume_mode}. Available resume modes are: {RESUME_MODES}"
         pylogger.error(message)
         raise ValueError(message)
 
     flags = RESUME_MODES[resume_mode]
-    restore_model = flags["restore_model"]
-    restore_run = flags["restore_run"]
 
     if ckpt_or_run_path is not None:
-        if restore_model:
-            resume_ckpt_path = resolve_ckpt(ckpt_or_run_path)
-            pylogger.info(f"Resume training from: '{resume_ckpt_path}'")
+        resume_ckpt_path = resolve_ckpt(ckpt_or_run_path)
+        pylogger.info(f"Resolved checkpoint path: '{resume_ckpt_path}'")
 
-        if restore_run:
+        if flags["logging"]:
            run_path = resolve_run_path(ckpt_or_run_path)
            resume_run_version = resolve_run_version(run_path=run_path)
            pylogger.info(f"Resume logging to: '{run_path}'")

src/nn_core/serialization.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
+import importlib
+import inspect
+import logging
+import os
+import shutil
+import tempfile
+import zipfile
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Type, Union
+
+import pytorch_lightning as pl
+from pytorch_lightning.plugins import TorchCheckpointIO
+
+METADATA_KEY: str = "metadata"
+
+pylogger = logging.getLogger(__name__)
+
+
+_METADATA_MODULE_KEY = f"{METADATA_KEY}_module"
+_METADATA_CLASS_KEY = f"{METADATA_KEY}_class"
+
+
+def _normalize_path(path: Union[Path, str]) -> Path:
+    if isinstance(path, str):
+        path = Path(path)
+    return (path.parent / path.stem.split(".")[0]).with_suffix(".ckpt.zip")
+
+
+class NNCheckpointIO(TorchCheckpointIO):
+    def __init__(self, jailing_dir: Optional[str] = None):
+        self.jailing_dir = jailing_dir
+
+    @classmethod
+    def load(cls, path: Path, map_location: Optional[Callable] = lambda storage, loc: storage):
+        return cls().load_checkpoint(path=str(path), map_location=map_location)
+
+    def save_checkpoint(self, checkpoint: Dict[str, Any], path, storage_options: Optional[Any] = None) -> None:
+        checkpoint_dir: Path = _normalize_path(path=path)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tmp_dir = Path(tmp_dir)
+
+            if METADATA_KEY in checkpoint:
+                metadata = checkpoint[METADATA_KEY]
+
+                metadata_path: Path = tmp_dir / METADATA_KEY
+                metadata_path.mkdir(exist_ok=True, parents=True)
+                metadata.save(dst_path=metadata_path)
+
+                checkpoint[_METADATA_MODULE_KEY] = inspect.getmodule(metadata).__name__
+                checkpoint[_METADATA_CLASS_KEY] = type(metadata).__name__
+
+                del checkpoint[METADATA_KEY]
+
+            super().save_checkpoint(
+                checkpoint=checkpoint, path=tmp_dir / "checkpoint.ckpt", storage_options=storage_options
+            )
+
+            compress_checkpoint(src_dir=tmp_dir, dst_file=checkpoint_dir)
+
+    def load_checkpoint(self, path, map_location: Optional[Callable] = lambda storage, loc: storage) -> Dict[str, Any]:
+        # load_checkpoint called from Trainer/Callbacks
+        with extract_checkpoint(ckpt_file=Path(path)) as ckpt_dir:
+            checkpoint = super().load_checkpoint(path=ckpt_dir / "checkpoint.ckpt", map_location=map_location)
+
+            if _METADATA_MODULE_KEY in checkpoint:
+                metadata_path: Path = ckpt_dir / METADATA_KEY
+                if metadata_path.exists():
+                    metadata_module = importlib.import_module(checkpoint[_METADATA_MODULE_KEY])
+                    metadata = getattr(metadata_module, checkpoint[_METADATA_CLASS_KEY]).load(src_path=metadata_path)
+                    checkpoint[METADATA_KEY] = metadata
+                    del checkpoint[_METADATA_MODULE_KEY]
+                    del checkpoint[_METADATA_CLASS_KEY]
+                else:
+                    raise FileNotFoundError(
+                        "This checkpoint is corrupted. It appears data info is required but missing."
+                    )
+
+            return checkpoint
+
+    def remove_checkpoint(self, path) -> None:
+        if self.jailing_dir is None or path.startswith(self.jailing_dir):
+            _normalize_path(path).unlink()
+            pylogger.debug(f"Removing checkpoint from {path}")
+        else:
+            pylogger.debug(
+                "Ignoring checkpoint deletion since it pertains to another run: "
+                "https://github.com/PyTorchLightning/pytorch-lightning/issues/11379"
+            )
+
+
+def compress_checkpoint(src_dir: Path, dst_file: Path, delete_dir: bool = True):
+
+    with zipfile.ZipFile(_normalize_path(dst_file), "w") as zip_file:
+        for folder, subfolders, files in os.walk(src_dir):
+            folder: Path = Path(folder)
+            for file in files:
+                zip_file.write(
+                    folder / file,
+                    os.path.relpath(os.path.join(folder, file), src_dir),
+                    compress_type=zipfile.ZIP_DEFLATED,
+                )
+
+    if delete_dir:
+        pylogger.debug(f"Deleting the checkpoint folder: '{src_dir}'")
+        shutil.rmtree(path=src_dir, ignore_errors=True)
+
+
+@contextmanager
+def extract_checkpoint(ckpt_file: Path) -> Path:
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        pylogger.debug(f"Extracting archive file '{ckpt_file}' to temp dir '{tmp_dir}'")
+        with zipfile.ZipFile(_normalize_path(ckpt_file), "r") as compressed_ckpt:
+            compressed_ckpt.extractall(tmp_dir)
+        yield Path(tmp_dir)
+
+
+def load_model(module_class: Type[pl.LightningModule], checkpoint_path: Path):
+    checkpoint = NNCheckpointIO.load(path=checkpoint_path)
+
+    model = module_class._load_model_state(checkpoint=checkpoint, metadata=checkpoint["metadata"])
+    return model
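Note: NNCheckpointIO wraps Lightning's TorchCheckpointIO so every checkpoint is written as a .ckpt.zip archive containing the usual checkpoint.ckpt plus a metadata/ folder serialized from the datamodule, and load_model rebuilds a module from such an archive. A minimal sketch of how the plugin could be attached to a Trainer; the jailing_dir value, the checkpoint path and MyLightningModule are placeholders, not part of this commit:

from pathlib import Path

import pytorch_lightning as pl

from nn_core.serialization import NNCheckpointIO, load_model

# The CheckpointIO plugin makes the Trainer save/load zip-based checkpoints;
# jailing_dir restricts which checkpoints this run is allowed to delete.
trainer = pl.Trainer(plugins=[NNCheckpointIO(jailing_dir="./checkpoints")])

# Restoring a module together with its metadata (the hypothetical MyLightningModule
# must accept the `metadata` argument that load_model forwards to _load_model_state).
# model = load_model(module_class=MyLightningModule, checkpoint_path=Path("checkpoints/best.ckpt.zip"))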
