Skip to content

Commit 7ce3882

Browse files
committed
fix auc layer and add check for auc op (#12954)
* fix auc layer and add check for auc op
* use input to check if states are inited
* optimize code
1 parent 198ee1a commit 7ce3882

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

paddle/fluid/operators/auc_op.h

Lines changed: 14 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -60,6 +60,20 @@ class AucKernel : public framework::OpKernel<T> {
6060
const T* inference_data = predict->data<T>();
6161
const auto* label_data = label->data<int64_t>();
6262

63+
// check if states are inited.
64+
auto* tp_in = ctx.Input<Tensor>("TP");
65+
auto* fp_in = ctx.Input<Tensor>("FP");
66+
auto* tn_in = ctx.Input<Tensor>("TN");
67+
auto* fn_in = ctx.Input<Tensor>("FN");
68+
PADDLE_ENFORCE(tp_in->IsInitialized(), "true_positive is not inited!");
69+
PADDLE_ENFORCE(fp_in->IsInitialized(), "false_negative is not inited!");
70+
PADDLE_ENFORCE(tn_in->IsInitialized(), "true_negative is not inited!");
71+
PADDLE_ENFORCE(fn_in->IsInitialized(), "false_positive is not inited!");
72+
PADDLE_ENFORCE_EQ(tp_in->numel(), num_thresholds, "");
73+
PADDLE_ENFORCE_EQ(fp_in->numel(), num_thresholds, "");
74+
PADDLE_ENFORCE_EQ(tn_in->numel(), num_thresholds, "");
75+
PADDLE_ENFORCE_EQ(fn_in->numel(), num_thresholds, "");
76+
6377
auto* tp_data = true_positive->mutable_data<int64_t>(ctx.GetPlace());
6478
auto* fn_data = false_negative->mutable_data<int64_t>(ctx.GetPlace());
6579
auto* tn_data = true_negative->mutable_data<int64_t>(ctx.GetPlace());

python/paddle/fluid/layers/metric_op.py

Lines changed: 8 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -119,10 +119,14 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
119119
helper = LayerHelper("auc", **locals())
120120
auc_out = helper.create_tmp_variable(dtype="float64")
121121
# make tp, tn, fp, fn persistable, so that can accumulate all batches.
122-
tp = helper.create_global_variable(persistable=True, dtype='int64')
123-
tn = helper.create_global_variable(persistable=True, dtype='int64')
124-
fp = helper.create_global_variable(persistable=True, dtype='int64')
125-
fn = helper.create_global_variable(persistable=True, dtype='int64')
122+
tp = helper.create_global_variable(
123+
persistable=True, dtype='int64', shape=[num_thresholds])
124+
tn = helper.create_global_variable(
125+
persistable=True, dtype='int64', shape=[num_thresholds])
126+
fp = helper.create_global_variable(
127+
persistable=True, dtype='int64', shape=[num_thresholds])
128+
fn = helper.create_global_variable(
129+
persistable=True, dtype='int64', shape=[num_thresholds])
126130
for var in [tp, tn, fp, fn]:
127131
helper.set_variable_initializer(
128132
var, Constant(

0 commit comments

Comments (0)