11import logging
22import os
3- from typing import Optional
3+ from typing import List , Optional
44
55import dotenv
6+ import numpy as np
7+ from hydra .core .hydra_config import HydraConfig
8+ from omegaconf import DictConfig
9+ from pytorch_lightning import seed_everything
10+ from rich .prompt import Prompt
611
7- logger = logging .getLogger (__name__ )
12+ pylogger = logging .getLogger (__name__ )
813
914
1015def get_env (env_name : str , default : Optional [str ] = None ) -> str :
@@ -19,13 +24,17 @@ def get_env(env_name: str, default: Optional[str] = None) -> str:
1924 """
2025 if env_name not in os .environ :
2126 if default is None :
22- raise KeyError (f"{ env_name } not defined and no default value is present!" )
27+ message = f"{ env_name } not defined and no default value is present!"
28+ pylogger .error (message )
29+ raise KeyError (message )
2330 return default
2431
2532 env_value : str = os .environ [env_name ]
2633 if not env_value :
2734 if default is None :
28- raise ValueError (f"{ env_name } has yet to be configured and no default value is present!" )
35+ message = f"{ env_name } has yet to be configured and no default value is present!"
36+ pylogger .error (message )
37+ raise ValueError (message )
2938 return default
3039
3140 return env_value
@@ -42,3 +51,31 @@ def load_envs(env_file: Optional[str] = None) -> None:
4251 it searches for a `.env` file in the project.
4352 """
4453 dotenv .load_dotenv (dotenv_path = env_file , override = True )
54+
55+
56+ def enforce_tags (tags : Optional [List [str ]]) -> List [str ]:
57+ if tags is None :
58+ if "id" in HydraConfig ().cfg .hydra .job :
59+ # We are in multi-run setting (either via a sweep or a scheduler)
60+ message : str = "You need to specify 'core.tags' in a multi-run setting!"
61+ pylogger .error (message )
62+ raise ValueError (message )
63+
64+ pylogger .warning ("No tags provided, asking for tags..." )
65+ tags = Prompt .ask ("Enter a list of comma separated tags" , default = "develop" )
66+ tags = [x .strip () for x in tags .split ("," )]
67+
68+ pylogger .info (f"Tags: { tags if tags is not None else []} " )
69+ return tags
70+
71+
72+ def seed_index_everything (train_cfg : DictConfig ) -> None :
73+ if "seed_index" in train_cfg and train_cfg .seed_index is not None :
74+ seed_index = train_cfg .seed_index
75+ seed_everything (42 )
76+ seeds = np .random .randint (np .iinfo (np .int32 ).max , size = max (42 , seed_index + 1 ))
77+ seed = seeds [seed_index ]
78+ seed_everything (seed )
79+ pylogger .info (f"Setting seed { seed } from seeds[{ seed_index } ]" )
80+ else :
81+ pylogger .warning ("The seed has not been set! The reproducibility is not guaranteed." )
0 commit comments