Skip to content

Commit 1ca2cde

Browse files
authored
Merge pull request #12962 from jacquesqiao/cherry-pick-fix-auc-init
fix auc layer and add check for auc op (#12954)
2 parents 69fffae + 50d29fc commit 1ca2cde

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

python/paddle/fluid/layers/metric_op.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,14 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
119119
helper = LayerHelper("auc", **locals())
120120
auc_out = helper.create_tmp_variable(dtype="float64")
121121
# make tp, tn, fp, fn persistable, so that can accumulate all batches.
122-
tp = helper.create_global_variable(persistable=True, dtype='int64')
123-
tn = helper.create_global_variable(persistable=True, dtype='int64')
124-
fp = helper.create_global_variable(persistable=True, dtype='int64')
125-
fn = helper.create_global_variable(persistable=True, dtype='int64')
122+
tp = helper.create_global_variable(
123+
persistable=True, dtype='int64', shape=[num_thresholds])
124+
tn = helper.create_global_variable(
125+
persistable=True, dtype='int64', shape=[num_thresholds])
126+
fp = helper.create_global_variable(
127+
persistable=True, dtype='int64', shape=[num_thresholds])
128+
fn = helper.create_global_variable(
129+
persistable=True, dtype='int64', shape=[num_thresholds])
126130
for var in [tp, tn, fp, fn]:
127131
helper.set_variable_initializer(
128132
var, Constant(

0 commit comments

Comments
 (0)