Commit 3f7dabe

Merge pull request #7 from Flegyas/develop

Release version 0.0.3

2 parents: 7e55875 + 31211a7

File tree

4 files changed: 159 additions, 95 deletions


src/nn_core/callbacks.py

Lines changed: 13 additions & 66 deletions
@@ -1,80 +1,27 @@
-import dataclasses
 import logging
-from typing import Any, Dict, Optional
+from typing import Any, Dict
 
-import hydra
 import pytorch_lightning as pl
-from omegaconf import DictConfig
-from pytorch_lightning import Callback
-from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning import Callback, Trainer
 
-from nn_core.common import PROJECT_ROOT
 from nn_core.model_logging import NNLogger
 
 pylogger = logging.getLogger(__name__)
 
 
-@dataclasses.dataclass
-class Upload:
-    checkpoint: bool = True
-    source: bool = True
-
-
-class NNLoggerConfiguration(Callback):
-    def __init__(self, upload: Optional[Dict[str, bool]], logger: Optional[DictConfig], **kwargs):
-        self.upload: Upload = Upload(**upload)
-        self.logger_cfg = logger
-        self.kwargs = kwargs
-
-        self.wandb: bool = self.logger_cfg["_target_"].endswith("WandbLogger")
+class NNTemplateCore(Callback):
+    @staticmethod
+    def _is_nnlogger(trainer: Trainer) -> bool:
+        return isinstance(trainer.logger, NNLogger)
 
     def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        if isinstance(trainer.logger, NNLogger):
+        if self._is_nnlogger(trainer):
+            trainer.logger.upload_source()
             trainer.logger.log_configuration(model=pl_module)
-
-            if "wandb_watch" in self.kwargs:
-                trainer.logger.wrapped.watch(pl_module, **self.kwargs["wandb_watch"])
+            trainer.logger.watch_model(pl_module=pl_module)
 
     def on_save_checkpoint(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
-    ) -> dict:
-        data = [
-            ("best_model_path", trainer.checkpoint_callback.best_model_path),
-            ("best_model_score", str(trainer.checkpoint_callback.best_model_score.detach().cpu().item())),
-        ]
-        trainer.logger.log_text(key="storage_info", columns=["key", "value"], data=data)
-
-        return checkpoint
-
-    # on_init_end can be employed since the Trainer doesn't use the logger until then.
-    def on_init_end(self, trainer: "pl.Trainer") -> None:
-        if self.logger_cfg is None:
-            return
-
-        pylogger.info(f"Instantiating <{self.logger_cfg['_target_'].split('.')[-1]}>")
-
-        if trainer.fast_dev_run and self.wandb:
-            # Switch wandb mode to offline to prevent online logging
-            self.logger_cfg.mode = "offline"
-
-        logger: LightningLoggerBase = hydra.utils.instantiate(self.logger_cfg)
-
-        if self.upload.source:
-            if self.wandb:
-                logger.experiment.log_code(
-                    root=PROJECT_ROOT,
-                    name=None,
-                    include_fn=(
-                        lambda path: path.startswith(
-                            (
-                                str(PROJECT_ROOT / "conf"),
-                                str(PROJECT_ROOT / "src"),
-                                str(PROJECT_ROOT / "setup.cfg"),
-                                str(PROJECT_ROOT / "env.yaml"),
-                            )
-                        )
-                        and path.endswith((".py", ".yaml", ".yml", ".toml", ".cfg"))
-                    ),
-                )
-
-        trainer.logger.wrapped = logger
+        self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]
+    ) -> None:
+        if self._is_nnlogger(trainer):
+            trainer.logger.on_save_checkpoint(trainer=trainer, pl_module=pl_module, checkpoint=checkpoint)
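
The refactored callback is now a thin dispatcher: logger instantiation, source upload, and wandb watching all moved onto the logger itself (see model_logging.py below). A minimal wiring sketch, assuming a Hydra-style config shaped like the keys NNLogger reads; the concrete values and layout here are hypothetical, not part of the commit:

    # Sketch: attaching NNTemplateCore and the new NNLogger to a Trainer.
    import pytorch_lightning as pl
    from omegaconf import OmegaConf

    from nn_core.callbacks import NNTemplateCore
    from nn_core.model_logging import NNLogger

    # Hypothetical config fragments mirroring the keys NNLogger.__init__ accesses.
    cfg = OmegaConf.create(
        {
            "core": {"storage_dir": "./storage"},
            "train": {"trainer": {"fast_dev_run": False}},
        }
    )
    logging_cfg = OmegaConf.create(
        {
            "logger": {"_target_": "pytorch_lightning.loggers.WandbLogger", "mode": "online"},
            "upload": {"source": True},
        }
    )

    logger = NNLogger(logging_cfg=logging_cfg, cfg=cfg, resume_id=None)
    trainer = pl.Trainer(logger=logger, callbacks=[NNTemplateCore()], fast_dev_run=False)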

src/nn_core/hooks.py

Lines changed: 0 additions & 23 deletions
@@ -1,23 +0,0 @@
-from typing import Callable, Dict, NoReturn
-
-from omegaconf import DictConfig, OmegaConf
-
-
-class OnSaveCheckpointInjection:
-    def __init__(
-        self,
-        cfg: DictConfig,
-        on_save_checkpoint: Callable[[Dict], NoReturn],
-    ):
-        """Inject the configuration into the checkpoint monkey patching the on_save_checkpoint hook.
-
-        Args:
-            cfg: the configuration to inject
-            on_save_checkpoint: the on_save_checkpoint to monkey patch
-        """
-        self.cfg = cfg
-        self.on_save_checkpoint = on_save_checkpoint
-
-    def __call__(self, checkpoint: Dict) -> None:
-        self.on_save_checkpoint(checkpoint)
-        checkpoint["cfg"] = OmegaConf.to_container(self.cfg, resolve=True)
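
Note that the cfg injection this deleted hook performed (checkpoint["cfg"] = OmegaConf.to_container(...)) is not lost: it resurfaces as a regular on_save_checkpoint method on NNLogger in model_logging.py below, with no monkey patching involved.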

src/nn_core/model_logging.py

Lines changed: 83 additions & 6 deletions
@@ -1,29 +1,101 @@
 import argparse
+import logging
 import os
 from pathlib import Path
 from typing import Any, Dict, Optional, Union
 
+import hydra
 import pytorch_lightning
 from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import LightningLoggerBase
 
+from nn_core.common import PROJECT_ROOT
+
+pylogger = logging.getLogger(__name__)
+
+
 _STATS_KEY: str = "stats"
 
 
 class NNLogger(LightningLoggerBase):
 
     __doc__ = LightningLoggerBase.__doc__
 
-    def __init__(self, logger: Optional[LightningLoggerBase], storage_dir: str, cfg):
+    def __init__(self, logging_cfg: DictConfig, cfg: DictConfig, resume_id: Optional[str]):
         super().__init__()
-        self.wrapped: LightningLoggerBase = logger
-        self.storage_dir: str = storage_dir
+        self.logging_cfg = logging_cfg
         self.cfg = cfg
+        self.resume_id = resume_id
+
+        self.storage_dir: str = cfg.core.storage_dir
+        self.wandb: bool = self.logging_cfg.logger["_target_"].endswith("WandbLogger")
+
+        if self.cfg.train.trainer.fast_dev_run and self.wandb:
+            # Switch wandb mode to offline to prevent online logging
+            pylogger.info("Setting the logger in 'offline' mode")
+            self.logging_cfg.logger.mode = "offline"
 
-    def __getattr__(self, item):
+        pylogger.info(f"Instantiating <{self.logging_cfg.logger['_target_'].split('.')[-1]}>")
+        self.wrapped: LightningLoggerBase = hydra.utils.instantiate(self.logging_cfg.logger, version=self.resume_id)
+
+        # force experiment lazy initialization
+        _ = self.wrapped.experiment
+
+    def __getattr__(self, item: str) -> Any:
         if self.wrapped is not None:
+            pylogger.debug(f"Delegation with '__getattr__': {self.wrapped.__class__.__qualname__}.{item}")
             return getattr(self.wrapped, item)
 
+    def watch_model(self, pl_module: LightningModule):
+        if self.wandb and "wandb_watch" in self.logging_cfg:
+            pylogger.info("Starting to 'watch' the module")
+            self.wrapped.watch(pl_module, **self.logging_cfg["wandb_watch"])
+
+    def upload_source(self) -> None:
+        if self.logging_cfg.upload.source and self.wandb:
+            pylogger.info("Uploading source code to wandb")
+            self.wrapped.experiment.log_code(
+                root=PROJECT_ROOT,
+                name=None,
+                include_fn=(
+                    lambda path: path.startswith(
+                        (
+                            str(PROJECT_ROOT / "conf"),
+                            str(PROJECT_ROOT / "src"),
+                            str(PROJECT_ROOT / "setup.cfg"),
+                            str(PROJECT_ROOT / "env.yaml"),
+                        )
+                    )
+                    and path.endswith((".py", ".yaml", ".yml", ".toml", ".cfg"))
+                ),
+            )
+
+    def on_save_checkpoint(self, trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) -> None:
+        # Attach to each checkpoint saved the configuration and the wandb run path (to resume logging from
+        # only the checkpoint)
+        pylogger.debug("Attaching 'cfg' to the checkpoint")
+        checkpoint["cfg"] = OmegaConf.to_container(trainer.logger.cfg, resolve=True)
+
+        pylogger.debug("Attaching 'run_path' to the checkpoint")
+        checkpoint[
+            "run_path"
+        ] = f"{trainer.logger.experiment.entity}/{trainer.logger.experiment.project_name()}/{trainer.logger.version}"
+
+    def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
+        # Log the checkpoint meta information
+        self.add_path(obj_id="checkpoints/best", obj_path=checkpoint_callback.best_model_path)
+        self.add_path(
+            obj_id="checkpoints/best_score",
+            obj_path=str(checkpoint_callback.best_model_score.detach().cpu().item()),
+        )
+
+    def add_path(self, obj_id: str, obj_path: str) -> None:
+        key = f"paths/{obj_id}"
+        pylogger.debug(f"Logging '{key}'")
+        self.experiment.config.update({key: str(obj_path)}, allow_val_change=True)
+
     @property
     def save_dir(self) -> Optional[str]:
         return self.storage_dir
@@ -55,7 +127,8 @@ def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs):
             kwargs: Optional keywoard arguments, depends on the specific logger being used
         """
         raise RuntimeError(
-            "This method is called automatically by PyTorch Lightning if save_hyperparameters(logger=True) is called. The whole configuration is already logged by logger.log_configuration, set logger=False"
+            "This method is called automatically by PyTorch Lightning if save_hyperparameters(logger=True) is called. "
+            "The whole configuration is already logged by logger.log_configuration, set logger=False"
        )
 
     def log_text(self, *args, **kwargs) -> None:
@@ -115,12 +188,16 @@ def log_configuration(
         yaml_conf: str = OmegaConf.to_yaml(cfg=cfg)
         run_dir: Path = Path(self.run_dir)
         run_dir.mkdir(exist_ok=True, parents=True)
-        (run_dir / "config.yaml").write_text(yaml_conf)
+        config_save_path = run_dir / "config.yaml"
+        pylogger.debug(f"Saving the configuration in: {config_save_path}")
+        config_save_path.write_text(yaml_conf)
 
         # save number of model parameters
+        pylogger.debug("Injecting model statistics in the 'cfg'")
         cfg[f"{_STATS_KEY}/params_total"] = sum(p.numel() for p in model.parameters())
         cfg[f"{_STATS_KEY}/params_trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
         cfg[f"{_STATS_KEY}/params_not_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)
 
         # send hparams to all loggers
+        pylogger.debug("Logging 'cfg'")
         self.wrapped.log_hyperparams(cfg)
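
Because on_save_checkpoint now embeds both the resolved configuration and the wandb run path into every checkpoint, a checkpoint file alone identifies the run it came from. A small read-back sketch; the checkpoint path is hypothetical:

    # Sketch: inspecting the metadata NNLogger attaches to a checkpoint.
    import torch

    ckpt = torch.load("storage/checkpoints/best.ckpt", map_location="cpu")  # hypothetical path
    print(ckpt["run_path"])   # "<entity>/<project>/<run-id>", built in on_save_checkpoint
    print(type(ckpt["cfg"]))  # plain dict: the config resolved via OmegaConf.to_container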

src/nn_core/resume.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+import re
+from pathlib import Path
+from typing import Optional
+
+import torch
+import wandb
+from wandb.apis.public import Run
+
+RUN_PATH_PATTERN = re.compile(r"^([^/]+)/([^/]+)/([^/]+)$")
+
+
+def resolve_ckpt(ckpt_or_run_path: str) -> str:
+    """Resolve the run path or ckpt to a checkpoint.
+
+    Args:
+        ckpt_or_run_path: run identifier or checkpoint path
+
+    Returns:
+        an existing path towards the best checkpoint
+    """
+    if Path(ckpt_or_run_path).exists():
+        return ckpt_or_run_path
+
+    try:
+        api = wandb.Api()
+        run: Run = api.run(path=ckpt_or_run_path)
+        ckpt_or_run_path = run.config["paths/checkpoints/best"]
+        return ckpt_or_run_path
+    except wandb.errors.CommError:
+        raise ValueError(f"Checkpoint or run not found: {ckpt_or_run_path}")
+
+
+def resolve_run_path(ckpt_or_run_path: str) -> str:
+    """Resolve the run path or ckpt to a run path.
+
+    Args:
+        ckpt_or_run_path: run identifier or checkpoint path
+
+    Returns:
+        a wandb run path identifier
+    """
+    if RUN_PATH_PATTERN.match(ckpt_or_run_path):
+        return ckpt_or_run_path
+
+    try:
+        return torch.load(ckpt_or_run_path)["run_path"]
+    except FileNotFoundError:
+        raise ValueError(f"Checkpoint or run not found: {ckpt_or_run_path}")
+
+
+def resolve_run_version(ckpt_or_run_path: Optional[str] = None, run_path: Optional[str] = None) -> str:
+    """Resolve the run path or ckpt to the wandb run version.
+
+    Args:
+        ckpt_or_run_path: run identifier or checkpoint path
+        run_path: the run path if already available
+
+    Returns:
+        a wandb run version
+    """
+    if run_path is None:
+        run_path = resolve_run_path(ckpt_or_run_path)
+    return RUN_PATH_PATTERN.match(run_path).group(3)
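
A usage sketch for the new helpers; the entity/project/run-id values are hypothetical:

    # Sketch: resolving a wandb run identifier to its best checkpoint and back.
    from nn_core.resume import resolve_ckpt, resolve_run_path, resolve_run_version

    ckpt_path = resolve_ckpt("my-entity/my-project/1a2b3c4d")  # queries wandb for "paths/checkpoints/best"
    run_path = resolve_run_path(ckpt_path)                     # reads "run_path" back out of the checkpoint
    version = resolve_run_version(run_path=run_path)           # -> "1a2b3c4d", the third run-path component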
