Skip to content

Commit 866f746

Browse files
safooraySafoora Yousefi
andauthored
Userconf (#167)
Allow importing user provided pipeline configs from user provided modules, remove the necessity of every pipeline config to be part of Eureka --------- Co-authored-by: Safoora Yousefi <[email protected]>
1 parent 31d5fce commit 866f746

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

main.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,39 @@
22

33
import argparse
44
import logging
5+
import os
6+
import sys
57

68
from eureka_ml_insights import user_configs as configs
79
from eureka_ml_insights.configs import model_configs
810
from eureka_ml_insights.core import Pipeline
911

1012
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
1113

14+
15+
def import_from_path(module_path, class_name):
16+
"""
17+
Dynamically import a class from a module path.
18+
"""
19+
sys.path.append(os.path.dirname(os.path.abspath(module_path)))
20+
print(sys.path)
21+
import importlib.util
22+
23+
spec = importlib.util.spec_from_file_location("experiment_config", module_path)
24+
module = importlib.util.module_from_spec(spec)
25+
spec.loader.exec_module(module)
26+
# Get the experiment config class from the module
27+
if hasattr(module, class_name):
28+
return getattr(module, class_name)
29+
logging.info(f"Using experiment config class {class_name} from {module_path}.")
30+
else:
31+
raise ValueError(f"Experiment config class {class_name} not found in {module_path}.")
32+
33+
1234
if __name__ == "__main__":
1335
parser = argparse.ArgumentParser(description="Run the pipeline for the specified experiment config class name.")
1436
parser.add_argument("--exp_config", type=str, help="The name of the experiment config class to run.", required=True)
37+
parser.add_argument("--exp_config_path", type=str, help="Path to the experiment config file.", default=None)
1538
parser.add_argument(
1639
"--model_config", type=str, nargs="?", help="The name of the model config to use.", default=None
1740
)
@@ -90,7 +113,11 @@
90113
if experiment_config_class in dir(configs):
91114
experiment_config_class = getattr(configs, experiment_config_class)
92115
else:
93-
raise ValueError(f"Experiment config class {experiment_config_class} not found.")
116+
# If the experiment_config_class is not found in the configs module, try to import it from args.exp_config_path.
117+
if args.exp_config_path:
118+
experiment_config_class = import_from_path(args.exp_config_path, args.exp_config)
119+
else:
120+
raise ValueError(f"Experiment config class {args.exp_config} not found.")
94121
pipeline_config = experiment_config_class(exp_logdir=args.exp_logdir, **init_args).pipeline_config
95122
logging.info(f"Saving experiment logs in {pipeline_config.log_dir}.")
96123
pipeline = Pipeline(pipeline_config.component_configs, pipeline_config.log_dir)

0 commit comments

Comments
 (0)