1 | 1 | import argparse |
| 2 | +import logging |
2 | 3 | import os |
3 | 4 | from pathlib import Path |
4 | 5 | from typing import Any, Dict, Optional, Union |
5 | 6 |
| 7 | +import hydra |
6 | 8 | import pytorch_lightning |
7 | 9 | from omegaconf import DictConfig, OmegaConf |
| 10 | +from pytorch_lightning import LightningModule, Trainer |
| 11 | +from pytorch_lightning.callbacks import ModelCheckpoint |
8 | 12 | from pytorch_lightning.loggers import LightningLoggerBase |
9 | 13 |
| 14 | +from nn_core.common import PROJECT_ROOT |
| 15 | + |
| 16 | +pylogger = logging.getLogger(__name__) |
| 17 | + |
| 18 | + |
10 | 19 | _STATS_KEY: str = "stats" |
11 | 20 |
12 | 21 |
13 | 22 | class NNLogger(LightningLoggerBase): |
14 | 23 |
15 | 24 | __doc__ = LightningLoggerBase.__doc__ |
16 | 25 |
17 | | - def __init__(self, logger: Optional[LightningLoggerBase], storage_dir: str, cfg): |
| 26 | + def __init__(self, logging_cfg: DictConfig, cfg: DictConfig, resume_id: Optional[str]): |
18 | 27 | super().__init__() |
19 | | - self.wrapped: LightningLoggerBase = logger |
20 | | - self.storage_dir: str = storage_dir |
| 28 | + self.logging_cfg = logging_cfg |
21 | 29 | self.cfg = cfg |
| 30 | + self.resume_id = resume_id |
| 31 | + |
| 32 | + self.storage_dir: str = cfg.core.storage_dir |
| 33 | + self.wandb: bool = self.logging_cfg.logger["_target_"].endswith("WandbLogger") |
| 34 | + |
| 35 | + if self.cfg.train.trainer.fast_dev_run and self.wandb: |
| 36 | + # Switch wandb mode to offline to prevent online logging |
| 37 | + pylogger.info("Setting the logger in 'offline' mode") |
| 38 | + self.logging_cfg.logger.mode = "offline" |
22 | 39 |
23 | | - def __getattr__(self, item): |
| 40 | + pylogger.info(f"Instantiating <{self.logging_cfg.logger['_target_'].split('.')[-1]}>") |
| 41 | + self.wrapped: LightningLoggerBase = hydra.utils.instantiate(self.logging_cfg.logger, version=self.resume_id) |
| 42 | + |
| 43 | + # Force the lazy initialization of the underlying experiment |
| 44 | + _ = self.wrapped.experiment |
| 45 | + |
| 46 | + def __getattr__(self, item: str) -> Any: |
24 | 47 | if self.wrapped is not None: |
| 48 | + pylogger.debug(f"Delegation with '__getattr__': {self.wrapped.__class__.__qualname__}.{item}") |
25 | 49 | return getattr(self.wrapped, item) |
26 | 50 |
| 51 | + def watch_model(self, pl_module: LightningModule): |
| 52 | + if self.wandb and "wandb_watch" in self.logging_cfg: |
| 53 | + pylogger.info("Starting to 'watch' the module") |
| 54 | + self.wrapped.watch(pl_module, **self.logging_cfg["wandb_watch"]) |
| 55 | + |
| 56 | + def upload_source(self) -> None: |
| 57 | + if self.logging_cfg.upload.source and self.wandb: |
| 58 | + pylogger.info("Uploading source code to wandb") |
| 59 | + self.wrapped.experiment.log_code( |
| 60 | + root=PROJECT_ROOT, |
| 61 | + name=None, |
| 62 | + include_fn=( |
| 63 | + lambda path: path.startswith( |
| 64 | + ( |
| 65 | + str(PROJECT_ROOT / "conf"), |
| 66 | + str(PROJECT_ROOT / "src"), |
| 67 | + str(PROJECT_ROOT / "setup.cfg"), |
| 68 | + str(PROJECT_ROOT / "env.yaml"), |
| 69 | + ) |
| 70 | + ) |
| 71 | + and path.endswith((".py", ".yaml", ".yml", ".toml", ".cfg")) |
| 72 | + ), |
| 73 | + ) |
| 74 | + |
| 75 | + def on_save_checkpoint(self, trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) -> None: |
| 76 | + # Attach the configuration and the wandb run path to each saved checkpoint, so that logging |
| 77 | + # can be resumed from the checkpoint alone |
| 78 | + pylogger.debug("Attaching 'cfg' to the checkpoint") |
| 79 | + checkpoint["cfg"] = OmegaConf.to_container(trainer.logger.cfg, resolve=True) |
| 80 | + |
| 81 | + pylogger.debug("Attaching 'run_path' to the checkpoint") |
| 82 | + checkpoint[ |
| 83 | + "run_path" |
| 84 | + ] = f"{trainer.logger.experiment.entity}/{trainer.logger.experiment.project_name()}/{trainer.logger.version}" |
| 85 | + |
| 86 | + def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: |
| 87 | + # Log the checkpoint meta information |
| 88 | + self.add_path(obj_id="checkpoints/best", obj_path=checkpoint_callback.best_model_path) |
| 89 | + self.add_path( |
| 90 | + obj_id="checkpoints/best_score", |
| 91 | + obj_path=str(checkpoint_callback.best_model_score.detach().cpu().item()), |
| 92 | + ) |
| 93 | + |
| 94 | + def add_path(self, obj_id: str, obj_path: str) -> None: |
| 95 | + key = f"paths/{obj_id}" |
| 96 | + pylogger.debug(f"Logging '{key}'") |
| 97 | + self.experiment.config.update({key: str(obj_path)}, allow_val_change=True) |
| 98 | + |
27 | 99 | @property |
28 | 100 | def save_dir(self) -> Optional[str]: |
29 | 101 | return self.storage_dir |
@@ -55,7 +127,8 @@ def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs): |
56 | 128 | kwargs: Optional keyword arguments, depends on the specific logger being used |
56 | 128 | """ |
57 | 129 | raise RuntimeError( |
58 | | - "This method is called automatically by PyTorch Lightning if save_hyperparameters(logger=True) is called. The whole configuration is already logged by logger.log_configuration, set logger=False" |
| 130 | + "This method is called automatically by PyTorch Lightning if save_hyperparameters(logger=True) is called. " |
| 131 | + "The whole configuration is already logged by logger.log_configuration, set logger=False" |
59 | 132 | ) |
60 | 133 |
61 | 134 | def log_text(self, *args, **kwargs) -> None: |
@@ -115,12 +188,16 @@ def log_configuration( |
115 | 188 | yaml_conf: str = OmegaConf.to_yaml(cfg=cfg) |
116 | 189 | run_dir: Path = Path(self.run_dir) |
117 | 190 | run_dir.mkdir(exist_ok=True, parents=True) |
118 | | - (run_dir / "config.yaml").write_text(yaml_conf) |
| 191 | + config_save_path = run_dir / "config.yaml" |
| 192 | + pylogger.debug(f"Saving the configuration in: {config_save_path}") |
| 193 | + config_save_path.write_text(yaml_conf) |
119 | 194 |
120 | 195 | # save number of model parameters |
| 196 | + pylogger.debug("Injecting model statistics in the 'cfg'") |
121 | 197 | cfg[f"{_STATS_KEY}/params_total"] = sum(p.numel() for p in model.parameters()) |
122 | 198 | cfg[f"{_STATS_KEY}/params_trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad) |
123 | 199 | cfg[f"{_STATS_KEY}/params_not_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad) |
124 | 200 |
125 | 201 | # send hparams to all loggers |
| 202 | + pylogger.debug("Logging 'cfg'") |
126 | 203 | self.wrapped.log_hyperparams(cfg) |
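
For context, a minimal usage sketch of the NNLogger class defined in this file, assuming an OmegaConf/Hydra config that provides the keys the class reads above (core.storage_dir, train.trainer.fast_dev_run, logging.logger with a _target_, logging.upload.source); the concrete values and the project name are placeholders, and the W&B logger is kept in offline mode so the snippet can run without an account:

from omegaconf import OmegaConf
from pytorch_lightning import Trainer

# Hypothetical configuration: only the keys NNLogger reads, with placeholder values.
cfg = OmegaConf.create(
    {
        "core": {"storage_dir": "storage"},
        "train": {"trainer": {"fast_dev_run": False}},
        "logging": {
            "logger": {
                "_target_": "pytorch_lightning.loggers.WandbLogger",
                "project": "example-project",
                "mode": "offline",  # offline so no W&B account is required
            },
            "upload": {"source": False},
        },
    }
)

# resume_id=None starts a fresh run; passing the run id stored in a checkpoint's
# "run_path" would instead resume logging into the existing W&B run.
nn_logger = NNLogger(logging_cfg=cfg.logging, cfg=cfg, resume_id=None)
trainer = Trainer(logger=nn_logger, fast_dev_run=cfg.train.trainer.fast_dev_run)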