|
14 | 14 | import pandas as pd |
15 | 15 | from pkg_resources import get_distribution, parse_version |
16 | 16 | from sklearn import preprocessing |
17 | | -from sklearn.metrics import make_scorer |
18 | 17 | from sklearn.model_selection import GridSearchCV, StratifiedKFold |
19 | 18 |
|
| 19 | +from orca_python.metrics import compute_metric, load_metric_as_scorer |
20 | 20 | from orca_python.results import Results |
21 | 21 |
|
22 | 22 |
|
@@ -169,30 +169,19 @@ def run_experiment(self): |
169 | 169 | train_metrics = OrderedDict() |
170 | 170 | test_metrics = OrderedDict() |
171 | 171 | for metric_name in self.general_conf["metrics"]: |
172 | | - |
173 | | - try: |
174 | | - # Loading metric from file |
175 | | - module = __import__("orca_python").metrics |
176 | | - metric = getattr( |
177 | | - module, self.general_conf["cv_metric"].lower().strip() |
178 | | - ) |
179 | | - |
180 | | - except AttributeError: |
181 | | - raise AttributeError( |
182 | | - "No metric named '%s'" % metric_name.strip().lower() |
183 | | - ) |
184 | | - |
185 | 172 | # Get train scores |
186 | | - train_score = metric( |
187 | | - partition["train_outputs"], train_predicted_y |
| 173 | + train_score = compute_metric( |
| 174 | + metric_name, |
| 175 | + partition["train_outputs"], |
| 176 | + train_predicted_y, |
188 | 177 | ) |
189 | 178 | train_metrics[metric_name.strip() + "_train"] = train_score |
190 | 179 |
|
191 | 180 | # Get test scores |
192 | 181 | test_metrics[metric_name.strip() + "_test"] = np.nan |
193 | 182 | if "test_outputs" in partition: |
194 | | - test_score = metric( |
195 | | - partition["test_outputs"], test_predicted_y |
| 183 | + test_score = compute_metric( |
| 184 | + metric_name, partition["test_outputs"], test_predicted_y |
196 | 185 | ) |
197 | 186 | test_metrics[metric_name.strip() + "_test"] = test_score |
198 | 187 |
|
@@ -536,23 +525,8 @@ def _get_optimal_estimator( |
536 | 525 | optimal.refit_time_ = elapsed |
537 | 526 | return optimal |
538 | 527 |
|
539 | | - try: |
540 | | - module = __import__("orca_python").metrics |
541 | | - metric = getattr(module, self.general_conf["cv_metric"].lower().strip()) |
542 | | - |
543 | | - except AttributeError: |
544 | | - |
545 | | - if not isinstance(self.general_conf["cv_metric"], str): |
546 | | - raise AttributeError("cv_metric must be string") |
547 | | - |
548 | | - raise AttributeError( |
549 | | - "No metric named '%s' implemented" |
550 | | - % self.general_conf["cv_metric"].strip().lower() |
551 | | - ) |
552 | | - |
553 | | - # Making custom metrics compatible with sklearn |
554 | | - gib = module.greater_is_better(self.general_conf["cv_metric"].lower().strip()) |
555 | | - scoring_function = make_scorer(metric, greater_is_better=gib) |
| 528 | + metric_name = self.general_conf["cv_metric"].strip().lower() |
| 529 | + scoring_function = load_metric_as_scorer(metric_name) |
556 | 530 |
|
557 | 531 | # Creating object to split train data for cross-validation |
558 | 532 | # This will make GridSearch have a pseudo-random beheaviour |
|
0 commit comments