Skip to content

Commit 12ba65c

Browse files
authored
Merge pull request #173 from openml/improv/more_extensions
Allow custom metrics to be reported in results
2 parents 8849ce3 + ba3969d commit 12ba65c

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

amlb/results.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .datautils import accuracy_score, confusion_matrix, f1_score, log_loss, balanced_accuracy_score, mean_absolute_error, mean_squared_error, mean_squared_log_error, r2_score, roc_auc_score, read_csv, write_csv, is_data_frame, to_data_frame
1818
from .resources import get as rget, config as rconfig, output_dirs
1919
from .utils import Namespace, backup_file, cached, datetime_iso, memoize, profile
20+
from frameworks.shared.callee import get_extension
2021

2122
log = logging.getLogger(__name__)
2223

@@ -323,6 +324,13 @@ def __init__(self, predictions_df, info=None):
323324
def evaluate(self, metric):
324325
if hasattr(self, metric):
325326
return getattr(self, metric)()
327+
else:
328+
# A metric may be defined twice, once for the automl system to use (e.g.
329+
# as a scikit-learn scorer), and once in the amlb-compatible format.
330+
# The amlb-compatible format is marked with a trailing underscore.
331+
custom_metric = get_extension(rconfig().extensions_files, f"{metric}_")
332+
if custom_metric is not None:
333+
return custom_metric(self)
326334
# raise ValueError("Metric {metric} is not supported for {type}.".format(metric=metric, type=self.type))
327335
log.warning("Metric %s is not supported for %s!", metric, self.type)
328336
return nan

0 commit comments

Comments
 (0)