We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent fc50800 commit d54e51dCopy full SHA for d54e51d
python/paddle/fluid/metrics.py
@@ -580,10 +580,10 @@ def __init__(self, name, curve='ROC', num_thresholds=200):
580
self.tn_list = np.zeros((num_thresholds, ))
581
self.fp_list = np.zeros((num_thresholds, ))
582
583
- def update(self, predictions, labels):
+ def update(self, preds, labels):
584
if not _is_numpy_(labels):
585
raise ValueError("The 'labels' must be a numpy ndarray.")
586
- if not _is_numpy_(predictions):
+ if not _is_numpy_(preds):
587
raise ValueError("The 'predictions' must be a numpy ndarray.")
588
589
kepsilon = 1e-7 # to account for floating point imprecisions
@@ -596,12 +596,12 @@ def update(self, predictions, labels):
596
tp, fn, tn, fp = 0, 0, 0, 0
597
for i, lbl in enumerate(labels):
598
if lbl:
599
- if predictions[i, 1] >= thresh:
+ if preds[i, 1] >= thresh:
600
tp += 1
601
else:
602
fn += 1
603
604
605
fp += 1
606
607
tn += 1
0 commit comments