Skip to content

Commit cdf41b9

Browse files
BUG: Fix metric bug in Utilities (#33)
* BUG: Correct metric list usage for evaluation in Utilities
* BUG: Fix incorrect metrics import in Utilities
1 parent 74fd97c commit cdf41b9

File tree

1 file changed

+9
-35
lines changed

1 file changed

+9
-35
lines changed

orca_python/utilities/utilities.py

Lines changed: 9 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
import pandas as pd
1515
from pkg_resources import get_distribution, parse_version
1616
from sklearn import preprocessing
17-
from sklearn.metrics import make_scorer
1817
from sklearn.model_selection import GridSearchCV, StratifiedKFold
1918

19+
from orca_python.metrics import compute_metric, load_metric_as_scorer
2020
from orca_python.results import Results
2121

2222

@@ -169,30 +169,19 @@ def run_experiment(self):
169169
train_metrics = OrderedDict()
170170
test_metrics = OrderedDict()
171171
for metric_name in self.general_conf["metrics"]:
172-
173-
try:
174-
# Loading metric from file
175-
module = __import__("orca_python").metrics
176-
metric = getattr(
177-
module, self.general_conf["cv_metric"].lower().strip()
178-
)
179-
180-
except AttributeError:
181-
raise AttributeError(
182-
"No metric named '%s'" % metric_name.strip().lower()
183-
)
184-
185172
# Get train scores
186-
train_score = metric(
187-
partition["train_outputs"], train_predicted_y
173+
train_score = compute_metric(
174+
metric_name,
175+
partition["train_outputs"],
176+
train_predicted_y,
188177
)
189178
train_metrics[metric_name.strip() + "_train"] = train_score
190179

191180
# Get test scores
192181
test_metrics[metric_name.strip() + "_test"] = np.nan
193182
if "test_outputs" in partition:
194-
test_score = metric(
195-
partition["test_outputs"], test_predicted_y
183+
test_score = compute_metric(
184+
metric_name, partition["test_outputs"], test_predicted_y
196185
)
197186
test_metrics[metric_name.strip() + "_test"] = test_score
198187

@@ -536,23 +525,8 @@ def _get_optimal_estimator(
536525
optimal.refit_time_ = elapsed
537526
return optimal
538527

539-
try:
540-
module = __import__("orca_python").metrics
541-
metric = getattr(module, self.general_conf["cv_metric"].lower().strip())
542-
543-
except AttributeError:
544-
545-
if not isinstance(self.general_conf["cv_metric"], str):
546-
raise AttributeError("cv_metric must be string")
547-
548-
raise AttributeError(
549-
"No metric named '%s' implemented"
550-
% self.general_conf["cv_metric"].strip().lower()
551-
)
552-
553-
# Making custom metrics compatible with sklearn
554-
gib = module.greater_is_better(self.general_conf["cv_metric"].lower().strip())
555-
scoring_function = make_scorer(metric, greater_is_better=gib)
528+
metric_name = self.general_conf["cv_metric"].strip().lower()
529+
scoring_function = load_metric_as_scorer(metric_name)
556530

557531
# Creating object to split train data for cross-validation
558532
# This will make GridSearch have a pseudo-random behaviour

0 commit comments

Comments (0)