@@ -325,14 +325,14 @@ class Auc(MetricBase):
325
325
"""
326
326
327
327
def __init__ (self , name , curve = 'ROC' , num_thresholds = 200 ):
328
- super (MetricBase , self ).__init__ (name , curve , num_thresholds )
328
+ super (Auc , self ).__init__ (name = name )
329
329
self ._curve = curve
330
330
self ._num_thresholds = num_thresholds
331
331
self ._epsilon = 1e-6
332
- self .tp_list = np .ndarray ((num_thresholds , ))
333
- self .fn_list = np .ndarray ((num_thresholds , ))
334
- self .tn_list = np .ndarray ((num_thresholds , ))
335
- self .fp_list = np .ndarray ((num_thresholds , ))
332
+ self .tp_list = np .zeros ((num_thresholds , ))
333
+ self .fn_list = np .zeros ((num_thresholds , ))
334
+ self .tn_list = np .zeros ((num_thresholds , ))
335
+ self .fp_list = np .zeros ((num_thresholds , ))
336
336
337
337
def update (self , labels , predictions , axis = 1 ):
338
338
if not _is_numpy_ (labels ):
@@ -350,12 +350,12 @@ def update(self, labels, predictions, axis=1):
350
350
tp , fn , tn , fp = 0 , 0 , 0 , 0
351
351
for i , lbl in enumerate (labels ):
352
352
if lbl :
353
- if predictions [i , 0 ] >= thresh :
353
+ if predictions [i , 1 ] >= thresh :
354
354
tp += 1
355
355
else :
356
356
fn += 1
357
357
else :
358
- if predictions [i , 0 ] >= thresh :
358
+ if predictions [i , 1 ] >= thresh :
359
359
fp += 1
360
360
else :
361
361
tn += 1
0 commit comments