Skip to content

Commit 31211a7

Browse files
authored
Refactor NNLogger and the callback NNTemplateCore (#5)
* Refactor NNLogger and the callback NNTemplateCore * Fix type hint
1 parent 6c5da82 commit 31211a7

File tree

2 files changed

+93
-79
lines changed

2 files changed

+93
-79
lines changed

src/nn_core/callbacks.py

Lines changed: 12 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,86 +1,27 @@
1-
import dataclasses
21
import logging
3-
from typing import Any, Dict, Optional
2+
from typing import Any, Dict
43

5-
import hydra
64
import pytorch_lightning as pl
7-
from omegaconf import DictConfig, OmegaConf
8-
from pytorch_lightning import Callback
9-
from pytorch_lightning.loggers import LightningLoggerBase
5+
from pytorch_lightning import Callback, Trainer
106

11-
from nn_core.common import PROJECT_ROOT
127
from nn_core.model_logging import NNLogger
138

149
pylogger = logging.getLogger(__name__)
1510

1611

17-
@dataclasses.dataclass
18-
class Upload:
19-
checkpoint: bool = True
20-
source: bool = True
21-
22-
23-
class NNLoggerConfiguration(Callback):
24-
def __init__(self, upload: Optional[Dict[str, bool]], logger: Optional[DictConfig], **kwargs):
25-
self.upload: Upload = Upload(**upload)
26-
self.logger_cfg = logger
27-
self.kwargs = kwargs
28-
29-
self.wandb: bool = self.logger_cfg["_target_"].endswith("WandbLogger")
12+
class NNTemplateCore(Callback):
13+
@staticmethod
14+
def _is_nnlogger(trainer: Trainer) -> bool:
15+
return isinstance(trainer.logger, NNLogger)
3016

3117
def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
32-
if isinstance(trainer.logger, NNLogger):
18+
if self._is_nnlogger(trainer):
19+
trainer.logger.upload_source()
3320
trainer.logger.log_configuration(model=pl_module)
34-
35-
if "wandb_watch" in self.kwargs:
36-
trainer.logger.wrapped.watch(pl_module, **self.kwargs["wandb_watch"])
21+
trainer.logger.watch_model(pl_module=pl_module)
3722

3823
def on_save_checkpoint(
39-
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
24+
self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]
4025
) -> None:
41-
# Log to wandb the checkpoint meta information
42-
trainer.logger.add_path(obj_id="checkpoints/best", obj_path=trainer.checkpoint_callback.best_model_path)
43-
trainer.logger.add_path(
44-
obj_id="checkpoints/best_score",
45-
obj_path=str(trainer.checkpoint_callback.best_model_score.detach().cpu().item()),
46-
)
47-
48-
# Attach to each checkpoint saved the configuration and the wandb run path (to resume logging from
49-
# only the checkpoint)
50-
checkpoint["cfg"] = OmegaConf.to_container(trainer.logger.cfg, resolve=True)
51-
checkpoint[
52-
"run_path"
53-
] = f"{trainer.logger.experiment.entity}/{trainer.logger.experiment.project_name()}/{trainer.logger.version}"
54-
55-
# on_init_end can be employed since the Trainer doesn't use the logger until then.
56-
def on_init_end(self, trainer: "pl.Trainer") -> None:
57-
if self.logger_cfg is None:
58-
return
59-
60-
pylogger.info(f"Instantiating <{self.logger_cfg['_target_'].split('.')[-1]}>")
61-
62-
if trainer.fast_dev_run and self.wandb:
63-
# Switch wandb mode to offline to prevent online logging
64-
self.logger_cfg.mode = "offline"
65-
66-
logger: LightningLoggerBase = hydra.utils.instantiate(self.logger_cfg, version=trainer.logger.resume_id)
67-
68-
if self.upload.source:
69-
if self.wandb:
70-
logger.experiment.log_code(
71-
root=PROJECT_ROOT,
72-
name=None,
73-
include_fn=(
74-
lambda path: path.startswith(
75-
(
76-
str(PROJECT_ROOT / "conf"),
77-
str(PROJECT_ROOT / "src"),
78-
str(PROJECT_ROOT / "setup.cfg"),
79-
str(PROJECT_ROOT / "env.yaml"),
80-
)
81-
)
82-
and path.endswith((".py", ".yaml", ".yml", ".toml", ".cfg"))
83-
),
84-
)
85-
86-
trainer.logger.wrapped = logger
26+
if self._is_nnlogger(trainer):
27+
trainer.logger.on_save_checkpoint(trainer=trainer, pl_module=pl_module, checkpoint=checkpoint)

src/nn_core/model_logging.py

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,101 @@
11
import argparse
2+
import logging
23
import os
34
from pathlib import Path
45
from typing import Any, Dict, Optional, Union
56

7+
import hydra
68
import pytorch_lightning
79
from omegaconf import DictConfig, OmegaConf
10+
from pytorch_lightning import LightningModule, Trainer
11+
from pytorch_lightning.callbacks import ModelCheckpoint
812
from pytorch_lightning.loggers import LightningLoggerBase
913

14+
from nn_core.common import PROJECT_ROOT
15+
16+
pylogger = logging.getLogger(__name__)
17+
18+
1019
_STATS_KEY: str = "stats"
1120

1221

1322
class NNLogger(LightningLoggerBase):
1423

1524
__doc__ = LightningLoggerBase.__doc__
1625

17-
def __init__(self, logger: Optional[LightningLoggerBase], storage_dir: str, cfg: DictConfig, resume_id: str):
26+
def __init__(self, logging_cfg: DictConfig, cfg: DictConfig, resume_id: Optional[str]):
1827
super().__init__()
19-
self.wrapped: LightningLoggerBase = logger
20-
self.storage_dir: str = storage_dir
28+
self.logging_cfg = logging_cfg
2129
self.cfg = cfg
2230
self.resume_id = resume_id
2331

24-
def add_path(self, obj_id: str, obj_path: str) -> None:
25-
self.experiment.config.update({f"paths/{obj_id}": str(obj_path)}, allow_val_change=True)
32+
self.storage_dir: str = cfg.core.storage_dir
33+
self.wandb: bool = self.logging_cfg.logger["_target_"].endswith("WandbLogger")
34+
35+
if self.cfg.train.trainer.fast_dev_run and self.wandb:
36+
# Switch wandb mode to offline to prevent online logging
37+
pylogger.info("Setting the logger in 'offline' mode")
38+
self.logging_cfg.logger.mode = "offline"
39+
40+
pylogger.info(f"Instantiating <{self.logging_cfg.logger['_target_'].split('.')[-1]}>")
41+
self.wrapped: LightningLoggerBase = hydra.utils.instantiate(self.logging_cfg.logger, version=self.resume_id)
2642

27-
def __getattr__(self, item):
43+
# force experiment lazy initialization
44+
_ = self.wrapped.experiment
45+
46+
def __getattr__(self, item: str) -> Any:
2847
if self.wrapped is not None:
48+
pylogger.debug(f"Delegation with '__getattr__': {self.wrapped.__class__.__qualname__}.{item}")
2949
return getattr(self.wrapped, item)
3050

51+
def watch_model(self, pl_module: LightningModule):
52+
if self.wandb and "wandb_watch" in self.logging_cfg:
53+
pylogger.info("Starting to 'watch' the module")
54+
self.wrapped.watch(pl_module, **self.logging_cfg["wandb_watch"])
55+
56+
def upload_source(self) -> None:
57+
if self.logging_cfg.upload.source and self.wandb:
58+
pylogger.info("Uploading source code to wandb")
59+
self.wrapped.experiment.log_code(
60+
root=PROJECT_ROOT,
61+
name=None,
62+
include_fn=(
63+
lambda path: path.startswith(
64+
(
65+
str(PROJECT_ROOT / "conf"),
66+
str(PROJECT_ROOT / "src"),
67+
str(PROJECT_ROOT / "setup.cfg"),
68+
str(PROJECT_ROOT / "env.yaml"),
69+
)
70+
)
71+
and path.endswith((".py", ".yaml", ".yml", ".toml", ".cfg"))
72+
),
73+
)
74+
75+
def on_save_checkpoint(self, trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) -> None:
76+
# Attach to each checkpoint saved the configuration and the wandb run path (to resume logging from
77+
# only the checkpoint)
78+
pylogger.debug("Attaching 'cfg' to the checkpoint")
79+
checkpoint["cfg"] = OmegaConf.to_container(trainer.logger.cfg, resolve=True)
80+
81+
pylogger.debug("Attaching 'run_path' to the checkpoint")
82+
checkpoint[
83+
"run_path"
84+
] = f"{trainer.logger.experiment.entity}/{trainer.logger.experiment.project_name()}/{trainer.logger.version}"
85+
86+
def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
87+
# Log the checkpoint meta information
88+
self.add_path(obj_id="checkpoints/best", obj_path=checkpoint_callback.best_model_path)
89+
self.add_path(
90+
obj_id="checkpoints/best_score",
91+
obj_path=str(checkpoint_callback.best_model_score.detach().cpu().item()),
92+
)
93+
94+
def add_path(self, obj_id: str, obj_path: str) -> None:
95+
key = f"paths/{obj_id}"
96+
pylogger.debug(f"Logging '{key}'")
97+
self.experiment.config.update({key: str(obj_path)}, allow_val_change=True)
98+
3199
@property
32100
def save_dir(self) -> Optional[str]:
33101
return self.storage_dir
@@ -59,7 +127,8 @@ def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs):
59127
kwargs: Optional keywoard arguments, depends on the specific logger being used
60128
"""
61129
raise RuntimeError(
62-
"This method is called automatically by PyTorch Lightning if save_hyperparameters(logger=True) is called. The whole configuration is already logged by logger.log_configuration, set logger=False"
130+
"This method is called automatically by PyTorch Lightning if save_hyperparameters(logger=True) is called. "
131+
"The whole configuration is already logged by logger.log_configuration, set logger=False"
63132
)
64133

65134
def log_text(self, *args, **kwargs) -> None:
@@ -119,12 +188,16 @@ def log_configuration(
119188
yaml_conf: str = OmegaConf.to_yaml(cfg=cfg)
120189
run_dir: Path = Path(self.run_dir)
121190
run_dir.mkdir(exist_ok=True, parents=True)
122-
(run_dir / "config.yaml").write_text(yaml_conf)
191+
config_save_path = run_dir / "config.yaml"
192+
pylogger.debug(f"Saving the configuration in: {config_save_path}")
193+
config_save_path.write_text(yaml_conf)
123194

124195
# save number of model parameters
196+
pylogger.debug("Injecting model statistics in the 'cfg'")
125197
cfg[f"{_STATS_KEY}/params_total"] = sum(p.numel() for p in model.parameters())
126198
cfg[f"{_STATS_KEY}/params_trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
127199
cfg[f"{_STATS_KEY}/params_not_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)
128200

129201
# send hparams to all loggers
202+
pylogger.debug("Logging 'cfg'")
130203
self.wrapped.log_hyperparams(cfg)

0 commit comments

Comments
 (0)