|
2 | 2 |
|
3 | 3 | import argparse |
4 | 4 | import logging |
| 5 | +import os |
| 6 | +import sys |
5 | 7 |
|
6 | 8 | from eureka_ml_insights import user_configs as configs |
7 | 9 | from eureka_ml_insights.configs import model_configs |
8 | 10 | from eureka_ml_insights.core import Pipeline |
9 | 11 |
|
10 | 12 | logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") |
11 | 13 |
|
| 14 | + |
| 15 | +def import_from_path(module_path, class_name): |
| 16 | + """ |
| 17 | + Dynamically import a class from a module path. |
| 18 | + """ |
| 19 | + sys.path.append(os.path.dirname(os.path.abspath(module_path))) |
| 20 | + print(sys.path) |
| 21 | + import importlib.util |
| 22 | + |
| 23 | + spec = importlib.util.spec_from_file_location("experiment_config", module_path) |
| 24 | + module = importlib.util.module_from_spec(spec) |
| 25 | + spec.loader.exec_module(module) |
| 26 | + # Get the experiment config class from the module |
| 27 | + if hasattr(module, class_name): |
| 28 | + return getattr(module, class_name) |
| 29 | + logging.info(f"Using experiment config class {class_name} from {module_path}.") |
| 30 | + else: |
| 31 | + raise ValueError(f"Experiment config class {class_name} not found in {module_path}.") |
| 32 | + |
| 33 | + |
12 | 34 | if __name__ == "__main__": |
13 | 35 | parser = argparse.ArgumentParser(description="Run the pipeline for the specified experiment config class name.") |
14 | 36 | parser.add_argument("--exp_config", type=str, help="The name of the experiment config class to run.", required=True) |
| 37 | + parser.add_argument("--exp_config_path", type=str, help="Path to the experiment config file.", default=None) |
15 | 38 | parser.add_argument( |
16 | 39 | "--model_config", type=str, nargs="?", help="The name of the model config to use.", default=None |
17 | 40 | ) |
|
90 | 113 | if experiment_config_class in dir(configs): |
91 | 114 | experiment_config_class = getattr(configs, experiment_config_class) |
92 | 115 | else: |
93 | | - raise ValueError(f"Experiment config class {experiment_config_class} not found.") |
| 116 | + # If the experiment_config_class is not found in the configs module, try to import it from args.exp_config_path. |
| 117 | + if args.exp_config_path: |
| 118 | + experiment_config_class = import_from_path(args.exp_config_path, args.exp_config) |
| 119 | + else: |
| 120 | + raise ValueError(f"Experiment config class {args.exp_config} not found.") |
94 | 121 | pipeline_config = experiment_config_class(exp_logdir=args.exp_logdir, **init_args).pipeline_config |
95 | 122 | logging.info(f"Saving experiment logs in {pipeline_config.log_dir}.") |
96 | 123 | pipeline = Pipeline(pipeline_config.component_configs, pipeline_config.log_dir) |
|
0 commit comments