Skip to content

Commit d5df104

Browse files
authored
Merge pull request #12 from Flegyas/develop
Version 0.0.4
2 parents 3f7dabe + ba8be70 commit d5df104

File tree

5 files changed

+219
-6
lines changed

5 files changed

+219
-6
lines changed

src/nn_core/callbacks.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,16 @@ def _is_nnlogger(trainer: Trainer) -> bool:
1616

1717
def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
    """Upload the source, log the configuration and watch the model at train start.

    Nothing is done unless ``_is_nnlogger`` accepts the trainer.
    """
    if not self._is_nnlogger(trainer):
        return

    # Local alias that carries the type hint for the trainer's logger
    nn_logger: NNLogger = trainer.logger
    nn_logger.upload_source()
    nn_logger.log_configuration(model=pl_module)
    nn_logger.watch_model(pl_module=pl_module)
2223

24+
def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
    """Upload the run files through the logger when training ends.

    Nothing is done unless ``_is_nnlogger`` accepts the trainer.
    """
    if not self._is_nnlogger(trainer):
        return

    # Local alias that carries the type hint for the trainer's logger
    nn_logger: NNLogger = trainer.logger
    nn_logger.upload_run_files()
28+
2329
def on_save_checkpoint(
2430
self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]
2531
) -> None:

src/nn_core/common/utils.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
import logging
22
import os
3-
from typing import Optional
3+
from typing import List, Optional
44

55
import dotenv
6+
import numpy as np
7+
from hydra.core.hydra_config import HydraConfig
8+
from omegaconf import DictConfig
9+
from pytorch_lightning import seed_everything
10+
from rich.prompt import Prompt
611

7-
logger = logging.getLogger(__name__)
12+
pylogger = logging.getLogger(__name__)
813

914

1015
def get_env(env_name: str, default: Optional[str] = None) -> str:
@@ -19,13 +24,17 @@ def get_env(env_name: str, default: Optional[str] = None) -> str:
1924
"""
2025
if env_name not in os.environ:
2126
if default is None:
22-
raise KeyError(f"{env_name} not defined and no default value is present!")
27+
message = f"{env_name} not defined and no default value is present!"
28+
pylogger.error(message)
29+
raise KeyError(message)
2330
return default
2431

2532
env_value: str = os.environ[env_name]
2633
if not env_value:
2734
if default is None:
28-
raise ValueError(f"{env_name} has yet to be configured and no default value is present!")
35+
message = f"{env_name} has yet to be configured and no default value is present!"
36+
pylogger.error(message)
37+
raise ValueError(message)
2938
return default
3039

3140
return env_value
@@ -42,3 +51,31 @@ def load_envs(env_file: Optional[str] = None) -> None:
4251
it searches for a `.env` file in the project.
4352
"""
4453
dotenv.load_dotenv(dotenv_path=env_file, override=True)
54+
55+
56+
def enforce_tags(tags: Optional[List[str]]) -> List[str]:
    """Return the run tags, asking for them interactively when none are given.

    Args:
        tags: the configured tags, or None when they have not been specified.

    Returns:
        ``tags`` unchanged when provided; otherwise the whitespace-stripped,
        non-empty tags typed at the prompt (the prompt default if none remain).

    Raises:
        ValueError: if ``tags`` is None and the Hydra job configuration
            contains an ``id`` key (multi-run setting).
    """
    if tags is None:
        if "id" in HydraConfig().cfg.hydra.job:
            # We are in multi-run setting (either via a sweep or a scheduler)
            message: str = "You need to specify 'core.tags' in a multi-run setting!"
            pylogger.error(message)
            raise ValueError(message)

        pylogger.warning("No tags provided, asking for tags...")
        default_tag = "develop"
        # Keep the raw answer in its own variable: it is a str, not a list of tags
        answer: str = Prompt.ask("Enter a list of comma separated tags", default=default_tag)
        # Drop the empty entries produced by stray commas ("a,,b", "a,", ",")
        tags = [tag for tag in (chunk.strip() for chunk in answer.split(",")) if tag]
        if not tags:
            # Nothing usable was typed: behave as if the prompt default was accepted
            tags = [default_tag]

    pylogger.info(f"Tags: {tags}")
    return tags
70+
71+
72+
def seed_index_everything(train_cfg: DictConfig) -> None:
    """Seed everything with the seed selected by ``train_cfg.seed_index``.

    ``seed_everything(42)`` is called first, then a list of candidate seeds is
    drawn with ``np.random.randint`` and the one at position ``seed_index`` is
    passed to ``seed_everything``. When ``seed_index`` is missing or None only
    a warning is logged.
    """
    index_available = "seed_index" in train_cfg and train_cfg.seed_index is not None
    if not index_available:
        pylogger.warning("The seed has not been set! The reproducibility is not guaranteed.")
        return

    seed_index = train_cfg.seed_index

    # Fixed master seed set before drawing the candidate seeds
    seed_everything(42)

    upper_bound = np.iinfo(np.int32).max
    # Draw at least 42 candidates, more if the requested index requires it
    n_candidates = max(42, seed_index + 1)
    seed = np.random.randint(upper_bound, size=n_candidates)[seed_index]

    seed_everything(seed)
    pylogger.info(f"Setting seed {seed} from seeds[{seed_index}]")

src/nn_core/model_logging.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import logging
33
import os
4+
import shutil
45
from pathlib import Path
56
from typing import Any, Dict, Optional, Union
67

@@ -55,7 +56,7 @@ def watch_model(self, pl_module: LightningModule):
5556

5657
def upload_source(self) -> None:
5758
if self.logging_cfg.upload.source and self.wandb:
58-
pylogger.info("Uploading source code to wandb")
59+
pylogger.info("Uploading source code to W&B")
5960
self.wrapped.experiment.log_code(
6061
root=PROJECT_ROOT,
6162
name=None,
@@ -201,3 +202,12 @@ def log_configuration(
201202
# send hparams to all loggers
202203
pylogger.debug("Logging 'cfg'")
203204
self.wrapped.log_hyperparams(cfg)
205+
206+
def upload_run_files(self):
    """Copy the run directory into ``run_files`` inside the W&B experiment directory.

    Nothing is done unless ``logging_cfg.upload.run_files`` and ``self.wandb``
    are both truthy.
    """
    if not (self.logging_cfg.upload.run_files and self.wandb):
        return

    pylogger.info("Uploading run files to W&B")
    source_dir = self.run_dir
    target_dir = f"{self.wrapped.experiment.dir}/run_files"
    shutil.copytree(source_dir, target_dir)

    # FIXME: symlink not working for some reason
    # os.symlink(self.run_dir, f"{self.wrapped.experiment.dir}/run_files", target_is_directory=True)

src/nn_core/resume.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,33 @@
1+
import logging
12
import re
3+
from operator import xor
24
from pathlib import Path
3-
from typing import Optional
5+
from typing import Optional, Tuple
46

57
import torch
68
import wandb
9+
from omegaconf import DictConfig
710
from wandb.apis.public import Run
811

12+
pylogger = logging.getLogger(__name__)
13+
914
RUN_PATH_PATTERN = re.compile(r"^([^/]+)/([^/]+)/([^/]+)$")
1015

16+
# Restore flags for every supported resume mode; the None key is the entry
# used when no resume mode is configured.
RESUME_MODES = {
    "continue": {"restore_model": True, "restore_run": True},
    "hotstart": {"restore_model": True, "restore_run": False},
    None: {"restore_model": False, "restore_run": False},
}
30+
1131

1232
def resolve_ckpt(ckpt_or_run_path: str) -> str:
1333
"""Resolve the run path or ckpt to a checkpoint.
@@ -61,3 +81,35 @@ def resolve_run_version(ckpt_or_run_path: Optional[str] = None, run_path: Option
6181
if run_path is None:
6282
run_path = resolve_run_path(ckpt_or_run_path)
6383
return RUN_PATH_PATTERN.match(run_path).group(3)
84+
85+
86+
def parse_restore(restore_cfg: DictConfig) -> Tuple[Optional[str], Optional[str]]:
    """Turn the restore configuration into a checkpoint path and a run version.

    Args:
        restore_cfg: configuration exposing ``ckpt_or_run_path`` and ``mode``.

    Returns:
        A ``(resume_ckpt_path, resume_run_version)`` pair. An entry is None when
        the selected mode does not restore it or when no path is configured.

    Raises:
        ValueError: if ``mode`` is not one of the keys of ``RESUME_MODES``.
    """
    ckpt_or_run_path = restore_cfg.ckpt_or_run_path
    resume_mode = restore_cfg.mode

    resume_ckpt_path = None
    resume_run_version = None

    # A path without a mode (or a mode without a path) is suspicious but not fatal
    if xor(bool(ckpt_or_run_path), bool(resume_mode)):
        pylogger.warning(f"Inconsistent resume modality {resume_mode} and checkpoint path '{ckpt_or_run_path}'")

    if resume_mode not in RESUME_MODES:
        message = f"Unsupported resume mode {resume_mode}. Available resume modes are: {RESUME_MODES}"
        pylogger.error(message)
        raise ValueError(message)

    flags = RESUME_MODES[resume_mode]
    restore_model = flags["restore_model"]
    restore_run = flags["restore_run"]

    # Truthiness rather than ``is not None``: an empty path counts as missing,
    # exactly as in the consistency check above, and must not be resolved.
    if ckpt_or_run_path:
        if restore_model:
            resume_ckpt_path = resolve_ckpt(ckpt_or_run_path)
            pylogger.info(f"Resume training from: '{resume_ckpt_path}'")

        if restore_run:
            run_path = resolve_run_path(ckpt_or_run_path)
            resume_run_version = resolve_run_version(run_path=run_path)
            pylogger.info(f"Resume logging to: '{run_path}'")

    return resume_ckpt_path, resume_run_version

src/nn_core/ui.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import datetime
2+
import operator
3+
from pathlib import Path
4+
from typing import List
5+
6+
import hydra
7+
import omegaconf
8+
import streamlit as st
9+
import wandb
10+
from hydra.core.global_hydra import GlobalHydra
11+
from hydra.experimental import compose
12+
from stqdm import stqdm
13+
14+
from nn_core.common import PROJECT_ROOT
15+
16+
WANDB_DIR: Path = PROJECT_ROOT / "wandb"
17+
WANDB_DIR.mkdir(exist_ok=True, parents=True)
18+
19+
st_run_sel = st.sidebar
20+
21+
22+
def local_checkpoint_selection(run_dir: Path, st_key: str) -> Path:
    """Let the user choose one of the checkpoints found under ``run_dir``.

    Every path matching ``checkpoints/*`` below ``run_dir`` is offered in a
    select box labelled with the file name. When there is none, an error is
    shown and ``st.stop()`` is called.
    """
    candidates: List[Path] = [*run_dir.rglob("checkpoints/*")]
    if not candidates:
        st.error(f"There's no checkpoint under {run_dir}! Are you sure the restore was successful?")
        st.stop()

    selected: Path = st_run_sel.selectbox(
        label="Select a checkpoint",
        index=0,
        options=candidates,
        format_func=operator.attrgetter("name"),
        key=f"checkpoint_select_{st_key}",
    )
    return selected
36+
37+
38+
def get_run_dir(entity: str, project: str, run_id: str) -> Path:
    """Get the local directory of a W&B run, downloading its files on request.

    :param entity: first token of the "entity/project/run_id" run path
    :param project: second token of the run path
    :param run_id: third token of the run path, the id of the run
    :return: the only directory under ``WANDB_DIR`` whose name ends with
        ``run_id`` if it exists, otherwise the ``files`` directory the run
        has just been downloaded into
    """
    api = wandb.Api()
    run = api.run(path=f"{entity}/{project}/{run_id}")
    st.sidebar.markdown(body=f"[`Open on WandB`]({run.url})")

    matching_runs: List[Path] = [item for item in WANDB_DIR.iterdir() if item.is_dir() and item.name.endswith(run_id)]

    if len(matching_runs) > 1:
        st.error(f"More than one run matching unique id {run_id}! Are you sure about that?")
        st.stop()

    if len(matching_runs) == 1:
        return matching_runs[0]

    only_checkpoint: bool = st_run_sel.checkbox(label="Download only the checkpoint?", value=True)
    if not st_run_sel.button(label="Download"):
        # No local copy and no download requested: nothing to return yet
        st.stop()

    # The creation timestamp is only needed to name the download directory,
    # so it is parsed here and not before looking for a local copy.
    created_at: datetime.datetime = datetime.datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%S")
    timestamp: str = created_at.strftime("%Y%m%d_%H%M%S")
    run_dir: Path = WANDB_DIR / f"restored-{timestamp}-{run.id}" / "files"

    files = [file for file in run.files() if "checkpoint" in file.name or not only_checkpoint]
    if len(files) == 0:
        st.error(f"There is no file to download from this run! Check on WandB: {run.url}")
        # Stop like every other error path: with nothing downloaded ``run_dir``
        # would never be created and must not be returned.
        st.stop()

    for file in stqdm(files, desc="Downloading files..."):
        file.download(root=run_dir)
    return run_dir
71+
72+
73+
def select_run_path(st_key: str, default_run_path: str):
    """Ask for a run path and return its three ``/``-separated tokens.

    ``st.stop()`` is called when the input is empty; an error is shown before
    stopping when the input does not split into exactly three tokens.
    """
    widget_key = f"run_path_select_{st_key}"
    run_path: str = st_run_sel.text_input(
        label="Run path (entity/project/id):",
        value=default_run_path,
        key=widget_key,
    )
    if not run_path:
        st.stop()

    parts: List[str] = run_path.split("/")
    if len(parts) != 3:
        st.error(f"This run path {run_path} doesn't look like a WandB run path! Are you sure about that?")
        st.stop()

    return parts
87+
88+
89+
def select_checkpoint(st_key: str = "MyAwesomeModel", default_run_path: str = ""):
    """Chain run path selection, run directory lookup and checkpoint selection.

    Returns whatever ``local_checkpoint_selection`` returns for the run
    directory obtained from the selected run path.
    """
    run_tokens = select_run_path(st_key=st_key, default_run_path=default_run_path)
    entity, project, run_id = run_tokens

    run_dir: Path = get_run_dir(entity=entity, project=project, run_id=run_id)
    checkpoint = local_checkpoint_selection(run_dir, st_key=st_key)
    return checkpoint
95+
96+
97+
def get_hydra_cfg(config_name: str = "default") -> omegaconf.DictConfig:
    """Compose and return the hydra config -- streamlit and jupyter compatible.

    Args:
        config_name: name of the .yaml configuration, without the extension

    Returns:
        The configuration composed from the ``conf`` directory under ``PROJECT_ROOT``
    """
    # Clear the global Hydra instance before initializing it again
    GlobalHydra.instance().clear()

    config_dir = str(PROJECT_ROOT / "conf")
    hydra.experimental.initialize_config_dir(config_dir=config_dir)

    cfg = compose(config_name=config_name)
    return cfg

0 commit comments

Comments
 (0)