22"""
33
44"""
5+ import inspect
56import math
67import os
78import pickle
1819
1920_MIN_BATCH_SIZE = 100000
2021
# Default keyword arguments forwarded to ``sk_metrics.recall_score`` (and to
# the scorer built in ``metric_to_scoring``).  The ``zero_division`` keyword
# is only accepted by newer scikit-learn releases, so probe the function
# signature at import time instead of hard-coding a version check; when it is
# supported, default it to 0.0 — presumably to make recall well-defined when
# a class has no predicted samples (TODO confirm against caller expectations).
_DEFAULT_RECALL_OPTIONS = {}
# NOTE: ``in`` on inspect's parameter mapping tests keys directly; no need
# for an explicit ``.keys()`` call.
if 'zero_division' in inspect.signature(sk_metrics.recall_score).parameters:
    _DEFAULT_RECALL_OPTIONS['zero_division'] = 0.0
25+
2126
2227def _task_to_average (task ):
2328 if task == const .TASK_MULTICLASS :
@@ -37,10 +42,15 @@ def calc_score(y_true, y_preds, y_proba=None, metrics=('accuracy',), task=const.
3742 if len (y_preds .shape ) == 2 and y_preds .shape [- 1 ] == 1 :
3843 y_preds = y_preds .reshape (- 1 )
3944
45+ recall_options = _DEFAULT_RECALL_OPTIONS .copy ()
46+
4047 if average is None :
4148 average = _task_to_average (task )
49+ recall_options ['average' ] = average
50+
51+ if classes is not None :
52+ recall_options ['labels' ] = classes
4253
43- recall_options = dict (average = average , labels = classes )
4454 if pos_label is not None :
4555 recall_options ['pos_label' ] = pos_label
4656
@@ -112,7 +122,8 @@ def metric_to_scoring(metric, task=const.TASK_BINARY, pos_label=None):
112122 raise ValueError (f'Not found matching scoring for { metric } ' )
113123
114124 if metric_lower in metric2fn .keys ():
115- options = dict (average = _task_to_average (task ))
125+ options = _DEFAULT_RECALL_OPTIONS .copy ()
126+ options ['average' ] = _task_to_average (task )
116127 if pos_label is not None :
117128 options ['pos_label' ] = pos_label
118129 scoring = sk_metrics .make_scorer (metric2fn [metric_lower ], ** options )
0 commit comments