|
10 | 10 | import os |
11 | 11 | import configparser as cfg |
12 | 12 | from datetime import date |
13 | | -from utils.configloader import EXP_NAME |
| 13 | +from utils.configloader import EXP_NAME, EXP_ORIGIN |
14 | 14 |
|
15 | 15 |
|
16 | 16 |
|
@@ -95,19 +95,35 @@ def get_process_settings(process_name, parameter_dict): |
95 | 95 |
|
96 | 96 |
|
97 | 97 | def setup_experiment(): |
98 | | - config = cfg.ConfigParser() |
99 | | - path = os.path.join(os.path.dirname(__file__), '..', 'configs', f'{EXP_NAME}.ini') |
100 | | - with open(path) as file: |
101 | | - config.read_file(file) |
102 | 98 |
|
103 | | - experiment_name = config['EXPERIMENT']['BASE'] |
104 | | - import importlib |
105 | | - mod = importlib.import_module('experiments.base.experiments') |
106 | | - try: |
107 | | - experiment_class = getattr(mod, experiment_name) |
108 | | - experiment = experiment_class() |
109 | | - except Exception: |
110 | | - raise ValueError(f'Experiment: {experiment_name} not in base.experiments.py.') |
| 99 | + if EXP_ORIGIN.upper() == 'BASE': |
| 100 | + config = cfg.ConfigParser() |
| 101 | + path = os.path.join(os.path.dirname(__file__), '..', 'configs', f'{EXP_NAME}.ini') |
| 102 | + try: |
| 103 | + with open(path) as file: |
| 104 | + config.read_file(file) |
| 105 | + except FileNotFoundError: |
| 106 | + raise FileNotFoundError(f'{EXP_NAME}.ini was not found. Make sure it exists.') |
| 107 | + |
| 108 | + experiment_name = config['EXPERIMENT']['BASE'] |
| 109 | + import importlib |
| 110 | + mod = importlib.import_module('experiments.base.experiments') |
| 111 | + try: |
| 112 | + experiment_class = getattr(mod, experiment_name) |
| 113 | + experiment = experiment_class() |
| 114 | + except Exception: |
| 115 | + raise ValueError(f'Experiment: {experiment_name} not in base.experiments.py.') |
| 116 | + |
| 117 | + elif EXP_ORIGIN.upper() == 'CUSTOM': |
| 118 | + |
| 119 | + experiment_name = EXP_NAME |
| 120 | + import importlib |
| 121 | + mod = importlib.import_module('experiments.custom.experiments') |
| 122 | + try: |
| 123 | + experiment_class = getattr(mod, experiment_name) |
| 124 | + experiment = experiment_class() |
| 125 | + except Exception: |
| 126 | + raise ValueError(f'Experiment: {experiment_name} not in custom.experiments.py.') |
111 | 127 |
|
112 | 128 | return experiment |
113 | 129 |
|
|
0 commit comments