Skip to content

Commit 3ae88f1

Browse files
committed
fix multi-label AP and multi-label AUC metrics
1 parent 7aaa7db commit 3ae88f1

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

libauc/metrics/metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def auc_roc_score(y_true, y_pred, reduction='mean', **kwargs):
88
r"""Evaluation function of AUROC"""
99
y_true = check_array_type(y_true)
1010
y_pred = check_array_type(y_pred)
11-
num_labels = y_true.shape[-1] if len(y_true) == 2 else 1
11+
num_labels = y_true.shape[-1] if len(y_true.shape) == 2 else 1
1212
y_true = check_array_shape(y_true, (-1, num_labels))
1313
y_pred = check_array_shape(y_pred, (-1, num_labels))
1414
assert reduction in ['mean', None, 'None'], 'Input is not valid!'
@@ -31,7 +31,7 @@ def auc_prc_score(y_true, y_pred, reduction='mean', **kwargs):
3131
r"""Evaluation function of AUPRC"""
3232
y_true = check_array_type(y_true)
3333
y_pred = check_array_type(y_pred)
34-
num_labels = y_true.shape[-1] if len(y_true) == 2 else 1
34+
num_labels = y_true.shape[-1] if len(y_true.shape) == 2 else 1
3535
y_true = check_array_shape(y_true, (-1, num_labels))
3636
y_pred = check_array_shape(y_pred, (-1, num_labels))
3737
if y_pred.shape[-1] != 1 and len(y_pred.shape)>1:

0 commit comments

Comments
 (0)