Skip to content

Commit 036b0fa

Browse files
committed
Update metrics
1 parent d4d195b commit 036b0fa

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

hypernets/tabular/metrics.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"""
33
44
"""
5+
import inspect
56
import math
67
import os
78
import pickle
@@ -18,6 +19,10 @@
1819

1920
_MIN_BATCH_SIZE = 100000
2021

22+
_DEFAULT_RECALL_OPTIONS = {}
23+
if 'zero_division' in inspect.signature(sk_metrics.recall_score).parameters.keys():
24+
_DEFAULT_RECALL_OPTIONS['zero_division'] = 0.0
25+
2126

2227
def _task_to_average(task):
2328
if task == const.TASK_MULTICLASS:
@@ -37,10 +42,15 @@ def calc_score(y_true, y_preds, y_proba=None, metrics=('accuracy',), task=const.
3742
if len(y_preds.shape) == 2 and y_preds.shape[-1] == 1:
3843
y_preds = y_preds.reshape(-1)
3944

45+
recall_options = _DEFAULT_RECALL_OPTIONS.copy()
46+
4047
if average is None:
4148
average = _task_to_average(task)
49+
recall_options['average'] = average
50+
51+
if classes is not None:
52+
recall_options['labels'] = classes
4253

43-
recall_options = dict(average=average, labels=classes)
4454
if pos_label is not None:
4555
recall_options['pos_label'] = pos_label
4656

@@ -112,7 +122,8 @@ def metric_to_scoring(metric, task=const.TASK_BINARY, pos_label=None):
112122
raise ValueError(f'Not found matching scoring for {metric}')
113123

114124
if metric_lower in metric2fn.keys():
115-
options = dict(average=_task_to_average(task))
125+
options = _DEFAULT_RECALL_OPTIONS.copy()
126+
options['average'] = _task_to_average(task)
116127
if pos_label is not None:
117128
options['pos_label'] = pos_label
118129
scoring = sk_metrics.make_scorer(metric2fn[metric_lower], **options)

0 commit comments

Comments
 (0)