Skip to content

Commit 79f9fb1

Browse files
authored
Move functions from template to core (#9)
* Move parse_restore to nn-core * Move enforce_tags to nn-core * Move seed_index logic to nn-core
1 parent 3f7dabe commit 79f9fb1

File tree

2 files changed

+94
-5
lines changed

2 files changed

+94
-5
lines changed

src/nn_core/common/utils.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
import logging
22
import os
3-
from typing import Optional
3+
from typing import List, Optional
44

55
import dotenv
6+
import numpy as np
7+
from hydra.core.hydra_config import HydraConfig
8+
from omegaconf import DictConfig
9+
from pytorch_lightning import seed_everything
10+
from rich.prompt import Prompt
611

7-
logger = logging.getLogger(__name__)
12+
pylogger = logging.getLogger(__name__)
813

914

1015
def get_env(env_name: str, default: Optional[str] = None) -> str:
@@ -19,13 +24,17 @@ def get_env(env_name: str, default: Optional[str] = None) -> str:
1924
"""
2025
if env_name not in os.environ:
2126
if default is None:
22-
raise KeyError(f"{env_name} not defined and no default value is present!")
27+
message = f"{env_name} not defined and no default value is present!"
28+
pylogger.error(message)
29+
raise KeyError(message)
2330
return default
2431

2532
env_value: str = os.environ[env_name]
2633
if not env_value:
2734
if default is None:
28-
raise ValueError(f"{env_name} has yet to be configured and no default value is present!")
35+
message = f"{env_name} has yet to be configured and no default value is present!"
36+
pylogger.error(message)
37+
raise ValueError(message)
2938
return default
3039

3140
return env_value
@@ -42,3 +51,31 @@ def load_envs(env_file: Optional[str] = None) -> None:
4251
it searches for a `.env` file in the project.
4352
"""
4453
dotenv.load_dotenv(dotenv_path=env_file, override=True)
54+
55+
56+
def enforce_tags(tags: Optional[List[str]]) -> List[str]:
57+
if tags is None:
58+
if "id" in HydraConfig().cfg.hydra.job:
59+
# We are in multi-run setting (either via a sweep or a scheduler)
60+
message: str = "You need to specify 'core.tags' in a multi-run setting!"
61+
pylogger.error(message)
62+
raise ValueError(message)
63+
64+
pylogger.warning("No tags provided, asking for tags...")
65+
tags = Prompt.ask("Enter a list of comma separated tags", default="develop")
66+
tags = [x.strip() for x in tags.split(",")]
67+
68+
pylogger.info(f"Tags: {tags if tags is not None else []}")
69+
return tags
70+
71+
72+
def seed_index_everything(train_cfg: DictConfig) -> None:
73+
if "seed_index" in train_cfg and train_cfg.seed_index is not None:
74+
seed_index = train_cfg.seed_index
75+
seed_everything(42)
76+
seeds = np.random.randint(np.iinfo(np.int32).max, size=max(42, seed_index + 1))
77+
seed = seeds[seed_index]
78+
seed_everything(seed)
79+
pylogger.info(f"Setting seed {seed} from seeds[{seed_index}]")
80+
else:
81+
pylogger.warning("The seed has not been set! The reproducibility is not guaranteed.")

src/nn_core/resume.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,33 @@
1+
import logging
12
import re
3+
from operator import xor
24
from pathlib import Path
3-
from typing import Optional
5+
from typing import Optional, Tuple
46

57
import torch
68
import wandb
9+
from omegaconf import DictConfig
710
from wandb.apis.public import Run
811

12+
pylogger = logging.getLogger(__name__)
13+
914
RUN_PATH_PATTERN = re.compile(r"^([^/]+)/([^/]+)/([^/]+)$")
1015

16+
RESUME_MODES = {
17+
"continue": {
18+
"restore_model": True,
19+
"restore_run": True,
20+
},
21+
"hotstart": {
22+
"restore_model": True,
23+
"restore_run": False,
24+
},
25+
None: {
26+
"restore_model": False,
27+
"restore_run": False,
28+
},
29+
}
30+
1131

1232
def resolve_ckpt(ckpt_or_run_path: str) -> str:
1333
"""Resolve the run path or ckpt to a checkpoint.
@@ -61,3 +81,35 @@ def resolve_run_version(ckpt_or_run_path: Optional[str] = None, run_path: Option
6181
if run_path is None:
6282
run_path = resolve_run_path(ckpt_or_run_path)
6383
return RUN_PATH_PATTERN.match(run_path).group(3)
84+
85+
86+
def parse_restore(restore_cfg: DictConfig) -> Tuple[Optional[str], Optional[str]]:
87+
ckpt_or_run_path = restore_cfg.ckpt_or_run_path
88+
resume_mode = restore_cfg.mode
89+
90+
resume_ckpt_path = None
91+
resume_run_version = None
92+
93+
if xor(bool(ckpt_or_run_path), bool(resume_mode)):
94+
pylogger.warning(f"Inconsistent resume modality {resume_mode} and checkpoint path '{ckpt_or_run_path}'")
95+
96+
if resume_mode not in RESUME_MODES:
97+
message = f"Unsupported resume mode {resume_mode}. Available resume modes are: {RESUME_MODES}"
98+
pylogger.error(message)
99+
raise ValueError(message)
100+
101+
flags = RESUME_MODES[resume_mode]
102+
restore_model = flags["restore_model"]
103+
restore_run = flags["restore_run"]
104+
105+
if ckpt_or_run_path is not None:
106+
if restore_model:
107+
resume_ckpt_path = resolve_ckpt(ckpt_or_run_path)
108+
pylogger.info(f"Resume training from: '{resume_ckpt_path}'")
109+
110+
if restore_run:
111+
run_path = resolve_run_path(ckpt_or_run_path)
112+
resume_run_version = resolve_run_version(run_path=run_path)
113+
pylogger.info(f"Resume logging to: '{run_path}'")
114+
115+
return resume_ckpt_path, resume_run_version

0 commit comments

Comments
 (0)