Skip to content

Commit 3c58b87

Browse files
authored
fix auc layer and add check for auc op (#12954)
* fix auc layer and add check for auc op * use input to check if states are inited * optimize code
1 parent c1488b1 commit 3c58b87

File tree

3 files changed

+23
-4
lines changed

3 files changed

+23
-4
lines changed

paddle/fluid/operators/auc_op.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,20 @@ class AucKernel : public framework::OpKernel<T> {
6060
const T* inference_data = predict->data<T>();
6161
const auto* label_data = label->data<int64_t>();
6262

63+
// check if states are inited.
64+
auto* tp_in = ctx.Input<Tensor>("TP");
65+
auto* fp_in = ctx.Input<Tensor>("FP");
66+
auto* tn_in = ctx.Input<Tensor>("TN");
67+
auto* fn_in = ctx.Input<Tensor>("FN");
68+
PADDLE_ENFORCE(tp_in->IsInitialized(), "true_positive is not inited!");
69+
PADDLE_ENFORCE(fp_in->IsInitialized(), "false_negative is not inited!");
70+
PADDLE_ENFORCE(tn_in->IsInitialized(), "true_negative is not inited!");
71+
PADDLE_ENFORCE(fn_in->IsInitialized(), "false_positive is not inited!");
72+
PADDLE_ENFORCE_EQ(tp_in->numel(), num_thresholds, "");
73+
PADDLE_ENFORCE_EQ(fp_in->numel(), num_thresholds, "");
74+
PADDLE_ENFORCE_EQ(tn_in->numel(), num_thresholds, "");
75+
PADDLE_ENFORCE_EQ(fn_in->numel(), num_thresholds, "");
76+
6377
auto* tp_data = true_positive->mutable_data<int64_t>(ctx.GetPlace());
6478
auto* fn_data = false_negative->mutable_data<int64_t>(ctx.GetPlace());
6579
auto* tn_data = true_negative->mutable_data<int64_t>(ctx.GetPlace());

paddle/fluid/operators/math/cpu_vec_test.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#include <sys/time.h>
1616
#include <cmath>
1717
#include <cstring>
18+
#include <random>
1819
#include <vector>
1920
#include "gflags/gflags.h"
2021
#include "glog/logging.h"

python/paddle/fluid/layers/metric_op.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,14 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
119119
helper = LayerHelper("auc", **locals())
120120
auc_out = helper.create_tmp_variable(dtype="float64")
121121
# make tp, tn, fp, fn persistable, so that can accumulate all batches.
122-
tp = helper.create_global_variable(persistable=True, dtype='int64')
123-
tn = helper.create_global_variable(persistable=True, dtype='int64')
124-
fp = helper.create_global_variable(persistable=True, dtype='int64')
125-
fn = helper.create_global_variable(persistable=True, dtype='int64')
122+
tp = helper.create_global_variable(
123+
persistable=True, dtype='int64', shape=[num_thresholds])
124+
tn = helper.create_global_variable(
125+
persistable=True, dtype='int64', shape=[num_thresholds])
126+
fp = helper.create_global_variable(
127+
persistable=True, dtype='int64', shape=[num_thresholds])
128+
fn = helper.create_global_variable(
129+
persistable=True, dtype='int64', shape=[num_thresholds])
126130
for var in [tp, tn, fp, fn]:
127131
helper.set_variable_initializer(
128132
var, Constant(

0 commit comments

Comments
 (0)