@@ -8,7 +8,7 @@ def auc_roc_score(y_true, y_pred, reduction='mean', **kwargs):
8
8
r"""Evaluation function of AUROC"""
9
9
y_true = check_array_type (y_true )
10
10
y_pred = check_array_type (y_pred )
11
- num_labels = y_true .shape [- 1 ] if len (y_true ) == 2 else 1
11
+ num_labels = y_true .shape [- 1 ] if len (y_true . shape ) == 2 else 1
12
12
y_true = check_array_shape (y_true , (- 1 , num_labels ))
13
13
y_pred = check_array_shape (y_pred , (- 1 , num_labels ))
14
14
assert reduction in ['mean' , None , 'None' ], 'Input is not valid!'
@@ -31,7 +31,7 @@ def auc_prc_score(y_true, y_pred, reduction='mean', **kwargs):
31
31
r"""Evaluation function of AUPRC"""
32
32
y_true = check_array_type (y_true )
33
33
y_pred = check_array_type (y_pred )
34
- num_labels = y_true .shape [- 1 ] if len (y_true ) == 2 else 1
34
+ num_labels = y_true .shape [- 1 ] if len (y_true . shape ) == 2 else 1
35
35
y_true = check_array_shape (y_true , (- 1 , num_labels ))
36
36
y_pred = check_array_shape (y_pred , (- 1 , num_labels ))
37
37
if y_pred .shape [- 1 ] != 1 and len (y_pred .shape )> 1 :
0 commit comments