Skip to content

Commit 1d4d8de

Browse files
authored
Merge pull request #11574 from jacquesqiao/fix-auc
fix auc
2 parents a009272 + 4aa5da0 commit 1d4d8de

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

python/paddle/fluid/metrics.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -325,14 +325,14 @@ class Auc(MetricBase):
325325
"""
326326

327327
def __init__(self, name, curve='ROC', num_thresholds=200):
328-
super(MetricBase, self).__init__(name, curve, num_thresholds)
328+
super(Auc, self).__init__(name=name)
329329
self._curve = curve
330330
self._num_thresholds = num_thresholds
331331
self._epsilon = 1e-6
332-
self.tp_list = np.ndarray((num_thresholds, ))
333-
self.fn_list = np.ndarray((num_thresholds, ))
334-
self.tn_list = np.ndarray((num_thresholds, ))
335-
self.fp_list = np.ndarray((num_thresholds, ))
332+
self.tp_list = np.zeros((num_thresholds, ))
333+
self.fn_list = np.zeros((num_thresholds, ))
334+
self.tn_list = np.zeros((num_thresholds, ))
335+
self.fp_list = np.zeros((num_thresholds, ))
336336

337337
def update(self, labels, predictions, axis=1):
338338
if not _is_numpy_(labels):
@@ -350,12 +350,12 @@ def update(self, labels, predictions, axis=1):
350350
tp, fn, tn, fp = 0, 0, 0, 0
351351
for i, lbl in enumerate(labels):
352352
if lbl:
353-
if predictions[i, 0] >= thresh:
353+
if predictions[i, 1] >= thresh:
354354
tp += 1
355355
else:
356356
fn += 1
357357
else:
358-
if predictions[i, 0] >= thresh:
358+
if predictions[i, 1] >= thresh:
359359
fp += 1
360360
else:
361361
tn += 1

0 commit comments

Comments
 (0)