Skip to content

Commit ba8be70

Browse files
lucmosFlegyas
andauthored
Move ui_utils entirely to nn-core (#11)
* Move ui_utils entirely to nn-core Co-authored-by: Valentino Maiorca <[email protected]>
1 parent 90a1737 commit ba8be70

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed

src/nn_core/ui.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import datetime
2+
import operator
3+
from pathlib import Path
4+
from typing import List
5+
6+
import hydra
7+
import omegaconf
8+
import streamlit as st
9+
import wandb
10+
from hydra.core.global_hydra import GlobalHydra
11+
from hydra.experimental import compose
12+
from stqdm import stqdm
13+
14+
from nn_core.common import PROJECT_ROOT
15+
16+
WANDB_DIR: Path = PROJECT_ROOT / "wandb"
17+
WANDB_DIR.mkdir(exist_ok=True, parents=True)
18+
19+
st_run_sel = st.sidebar
20+
21+
22+
def local_checkpoint_selection(run_dir: Path, st_key: str) -> Path:
23+
checkpoint_paths: List[Path] = list(run_dir.rglob("checkpoints/*"))
24+
if len(checkpoint_paths) == 0:
25+
st.error(f"There's no checkpoint under {run_dir}! Are you sure the restore was successful?")
26+
st.stop()
27+
checkpoint_path: Path = st_run_sel.selectbox(
28+
label="Select a checkpoint",
29+
index=0,
30+
options=checkpoint_paths,
31+
format_func=operator.attrgetter("name"),
32+
key=f"checkpoint_select_{st_key}",
33+
)
34+
35+
return checkpoint_path
36+
37+
38+
def get_run_dir(entity: str, project: str, run_id: str) -> Path:
39+
"""Get run directory.
40+
41+
:param run_path: "entity/project/run_id"
42+
:return:
43+
"""
44+
api = wandb.Api()
45+
run = api.run(path=f"{entity}/{project}/{run_id}")
46+
created_at: datetime = datetime.datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%S")
47+
st.sidebar.markdown(body=f"[`Open on WandB`]({run.url})")
48+
49+
timestamp: str = created_at.strftime("%Y%m%d_%H%M%S")
50+
51+
matching_runs: List[Path] = [item for item in WANDB_DIR.iterdir() if item.is_dir() and item.name.endswith(run_id)]
52+
53+
if len(matching_runs) > 1:
54+
st.error(f"More than one run matching unique id {run_id}! Are you sure about that?")
55+
st.stop()
56+
57+
if len(matching_runs) == 1:
58+
return matching_runs[0]
59+
60+
only_checkpoint: bool = st_run_sel.checkbox(label="Download only the checkpoint?", value=True)
61+
if st_run_sel.button(label="Download"):
62+
run_dir: Path = WANDB_DIR / f"restored-{timestamp}-{run.id}" / "files"
63+
files = [file for file in run.files() if "checkpoint" in file.name or not only_checkpoint]
64+
if len(files) == 0:
65+
st.error(f"There is no file to download from this run! Check on WandB: {run.url}")
66+
for file in stqdm(files, desc="Downloading files..."):
67+
file.download(root=run_dir)
68+
return run_dir
69+
else:
70+
st.stop()
71+
72+
73+
def select_run_path(st_key: str, default_run_path: str):
74+
run_path: str = st_run_sel.text_input(
75+
label="Run path (entity/project/id):",
76+
value=default_run_path,
77+
key=f"run_path_select_{st_key}",
78+
)
79+
if not run_path:
80+
st.stop()
81+
tokens: List[str] = run_path.split("/")
82+
if len(tokens) != 3:
83+
st.error(f"This run path {run_path} doesn't look like a WandB run path! Are you sure about that?")
84+
st.stop()
85+
86+
return tokens
87+
88+
89+
def select_checkpoint(st_key: str = "MyAwesomeModel", default_run_path: str = ""):
90+
entity, project, run_id = select_run_path(st_key=st_key, default_run_path=default_run_path)
91+
92+
run_dir: Path = get_run_dir(entity=entity, project=project, run_id=run_id)
93+
94+
return local_checkpoint_selection(run_dir, st_key=st_key)
95+
96+
97+
def get_hydra_cfg(config_name: str = "default") -> omegaconf.DictConfig:
98+
"""Instantiate and return the hydra config -- streamlit and jupyter compatible.
99+
100+
Args:
101+
config_name: .yaml configuration name, without the extension
102+
103+
Returns:
104+
The desired omegaconf.DictConfig
105+
"""
106+
GlobalHydra.instance().clear()
107+
hydra.experimental.initialize_config_dir(config_dir=str(PROJECT_ROOT / "conf"))
108+
return compose(config_name=config_name)

0 commit comments

Comments
 (0)