Skip to content

Commit 2b58c62

Browse files
authored
Update auc op (#12199)
Fix the AUC op and optimize its test
1 parent 37713f2 commit 2b58c62

File tree

5 files changed

+52
-108
lines changed

5 files changed

+52
-108
lines changed

paddle/fluid/operators/auc_op.cc

Lines changed: 13 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -24,15 +24,16 @@ class AucOp : public framework::OperatorWithKernel {
2424

2525
protected:
2626
void InferShape(framework::InferShapeContext *ctx) const override {
27-
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input of Out should not be null.");
28-
PADDLE_ENFORCE(ctx->HasInput("Indices"),
29-
"Input of Indices should not be null.");
27+
PADDLE_ENFORCE(ctx->HasInput("Predict"),
28+
"Input of Out should not be null.");
3029
PADDLE_ENFORCE(ctx->HasInput("Label"),
3130
"Input of Label should not be null.");
32-
auto inference_height = ctx->GetInputDim("Out")[0];
31+
auto predict_width = ctx->GetInputDim("Predict")[1];
32+
PADDLE_ENFORCE_EQ(predict_width, 2, "Only support binary classification");
33+
auto predict_height = ctx->GetInputDim("Predict")[0];
3334
auto label_height = ctx->GetInputDim("Label")[0];
3435

35-
PADDLE_ENFORCE_EQ(inference_height, label_height,
36+
PADDLE_ENFORCE_EQ(predict_height, label_height,
3637
"Out and Label should have same height.");
3738

3839
int num_thres = ctx->Attrs().Get<int>("num_thresholds");
@@ -43,33 +44,28 @@ class AucOp : public framework::OperatorWithKernel {
4344
ctx->SetOutputDim("FPOut", {num_thres});
4445
ctx->SetOutputDim("FNOut", {num_thres});
4546

46-
ctx->ShareLoD("Out", /*->*/ "AUC");
47+
ctx->ShareLoD("Predict", /*->*/ "AUC");
4748
}
4849

4950
protected:
5051
framework::OpKernelType GetExpectedKernelType(
5152
const framework::ExecutionContext &ctx) const override {
5253
return framework::OpKernelType(
53-
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
54+
framework::ToDataType(ctx.Input<Tensor>("Predict")->type()),
5455
ctx.device_context());
5556
}
5657
};
5758

5859
class AucOpMaker : public framework::OpProtoAndCheckerMaker {
5960
public:
6061
void Make() override {
61-
AddInput("Out",
62-
"A floating point 2D tensor, values are in the range [0, 1]."
63-
"Each row is sorted in descending order. This input should be the"
64-
"output of topk."
62+
AddInput("Predict",
63+
"A floating point 2D tensor with shape [batch_size, 2], values "
64+
"are in the range [0, 1]."
6565
"Typically, this tensor indicates the probability of each label");
66-
AddInput("Indices",
67-
"An int 2D tensor, indicating the indices of original"
68-
"tensor before sorting. Typically, this tensor indicates which "
69-
"label the probability stands for.");
7066
AddInput("Label",
71-
"A 2D int tensor indicating the label of the training data."
72-
"The height is batch size and width is always 1.");
67+
"A 2D int tensor indicating the label of the training data. "
68+
"shape: [batch_size, 1]");
7369
AddInput("TP", "True-Positive value.");
7470
AddInput("FP", "False-Positive value.");
7571
AddInput("TN", "True-Negative value.");

paddle/fluid/operators/auc_op.h

Lines changed: 20 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ template <typename DeviceContext, typename T>
3131
class AucKernel : public framework::OpKernel<T> {
3232
public:
3333
void Compute(const framework::ExecutionContext& ctx) const override {
34-
auto* inference = ctx.Input<Tensor>("Out");
34+
auto* predict = ctx.Input<Tensor>("Predict");
3535
auto* label = ctx.Input<Tensor>("Label");
3636
auto* auc = ctx.Output<Tensor>("AUC");
3737
// Only use output var for now, make sure it's persistable and
@@ -41,45 +41,44 @@ class AucKernel : public framework::OpKernel<T> {
4141
auto* true_negative = ctx.Output<Tensor>("TNOut");
4242
auto* false_negative = ctx.Output<Tensor>("FNOut");
4343

44-
float* auc_data = auc->mutable_data<float>(ctx.GetPlace());
44+
auto* auc_data = auc->mutable_data<double>(ctx.GetPlace());
4545

4646
std::string curve = ctx.Attr<std::string>("curve");
4747
int num_thresholds = ctx.Attr<int>("num_thresholds");
48-
std::vector<float> thresholds_list;
48+
std::vector<double> thresholds_list;
4949
thresholds_list.reserve(num_thresholds);
5050
for (int i = 1; i < num_thresholds - 1; i++) {
51-
thresholds_list[i] = static_cast<float>(i) / (num_thresholds - 1);
51+
thresholds_list[i] = static_cast<double>(i) / (num_thresholds - 1);
5252
}
53-
const float kEpsilon = 1e-7;
53+
const double kEpsilon = 1e-7;
5454
thresholds_list[0] = 0.0f - kEpsilon;
5555
thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon;
5656

57-
size_t batch_size = inference->dims()[0];
58-
size_t inference_width = inference->dims()[1];
57+
size_t batch_size = predict->dims()[0];
58+
size_t inference_width = predict->dims()[1];
5959

60-
const T* inference_data = inference->data<T>();
61-
const int64_t* label_data = label->data<int64_t>();
60+
const T* inference_data = predict->data<T>();
61+
const auto* label_data = label->data<int64_t>();
6262

6363
auto* tp_data = true_positive->mutable_data<int64_t>(ctx.GetPlace());
6464
auto* fn_data = false_negative->mutable_data<int64_t>(ctx.GetPlace());
6565
auto* tn_data = true_negative->mutable_data<int64_t>(ctx.GetPlace());
6666
auto* fp_data = false_positive->mutable_data<int64_t>(ctx.GetPlace());
6767

6868
for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) {
69-
// caculate TP, FN, TN, FP for current thresh
69+
// calculate TP, FN, TN, FP for current thresh
7070
int64_t tp = 0, fn = 0, tn = 0, fp = 0;
7171
for (size_t i = 0; i < batch_size; i++) {
72-
// NOTE: label_data used as bool, labels >0 will be treated as true.
72+
// NOTE: label_data used as bool, labels > 0 will be treated as true.
7373
if (label_data[i]) {
74-
// use first(max) data in each row
75-
if (inference_data[i * inference_width] >=
74+
if (inference_data[i * inference_width + 1] >=
7675
(thresholds_list[idx_thresh])) {
7776
tp++;
7877
} else {
7978
fn++;
8079
}
8180
} else {
82-
if (inference_data[i * inference_width] >=
81+
if (inference_data[i * inference_width + 1] >=
8382
(thresholds_list[idx_thresh])) {
8483
fp++;
8584
} else {
@@ -94,21 +93,21 @@ class AucKernel : public framework::OpKernel<T> {
9493
fp_data[idx_thresh] += fp;
9594
}
9695
// epsilon to avoid divide by zero.
97-
float epsilon = 1e-6;
96+
double epsilon = 1e-6;
9897
// Riemann sum to calculate auc.
9998
Tensor tp_rate, fp_rate, rec_rate;
10099
tp_rate.Resize({num_thresholds});
101100
fp_rate.Resize({num_thresholds});
102101
rec_rate.Resize({num_thresholds});
103-
float* tp_rate_data = tp_rate.mutable_data<float>(ctx.GetPlace());
104-
float* fp_rate_data = fp_rate.mutable_data<float>(ctx.GetPlace());
105-
float* rec_rate_data = rec_rate.mutable_data<float>(ctx.GetPlace());
102+
auto* tp_rate_data = tp_rate.mutable_data<double>(ctx.GetPlace());
103+
auto* fp_rate_data = fp_rate.mutable_data<double>(ctx.GetPlace());
104+
auto* rec_rate_data = rec_rate.mutable_data<double>(ctx.GetPlace());
106105
for (int i = 0; i < num_thresholds; i++) {
107-
tp_rate_data[i] = (static_cast<float>(tp_data[i]) + epsilon) /
106+
tp_rate_data[i] = (static_cast<double>(tp_data[i]) + epsilon) /
108107
(tp_data[i] + fn_data[i] + epsilon);
109108
fp_rate_data[i] =
110-
static_cast<float>(fp_data[i]) / (fp_data[i] + tn_data[i] + epsilon);
111-
rec_rate_data[i] = (static_cast<float>(tp_data[i]) + epsilon) /
109+
static_cast<double>(fp_data[i]) / (fp_data[i] + tn_data[i] + epsilon);
110+
rec_rate_data[i] = (static_cast<double>(tp_data[i]) + epsilon) /
112111
(tp_data[i] + fp_data[i] + epsilon);
113112
}
114113
*auc_data = 0.0f;

python/paddle/fluid/layers/metric_op.py

Lines changed: 7 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -114,23 +114,13 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
114114
prediction = network(image, is_infer=True)
115115
auc_out=fluid.layers.auc(input=prediction, label=label)
116116
"""
117-
118-
warnings.warn(
119-
"This interface is not recommended, fluid.layers.auc compute the auc at every minibatch, \
120-
but can not aggregate them and get the pass AUC, because pass \
121-
auc can not be averaged with weighted from the minibatch auc value. \
122-
Please use fluid.metrics.Auc, it can compute the auc value via Python natively, \
123-
which can get every minibatch and every pass auc value.", Warning)
124117
helper = LayerHelper("auc", **locals())
125-
topk_out = helper.create_tmp_variable(dtype=input.dtype)
126-
topk_indices = helper.create_tmp_variable(dtype="int64")
127-
topk_out, topk_indices = nn.topk(input, k=k)
128-
auc_out = helper.create_tmp_variable(dtype="float32")
118+
auc_out = helper.create_tmp_variable(dtype="float64")
129119
# make tp, tn, fp, fn persistable, so that can accumulate all batches.
130-
tp = helper.create_global_variable(persistable=True)
131-
tn = helper.create_global_variable(persistable=True)
132-
fp = helper.create_global_variable(persistable=True)
133-
fn = helper.create_global_variable(persistable=True)
120+
tp = helper.create_global_variable(persistable=True, dtype='int64')
121+
tn = helper.create_global_variable(persistable=True, dtype='int64')
122+
fp = helper.create_global_variable(persistable=True, dtype='int64')
123+
fn = helper.create_global_variable(persistable=True, dtype='int64')
134124
for var in [tp, tn, fp, fn]:
135125
helper.set_variable_initializer(
136126
var, Constant(
@@ -139,8 +129,7 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
139129
helper.append_op(
140130
type="auc",
141131
inputs={
142-
"Out": [topk_out],
143-
"Indices": [topk_indices],
132+
"Predict": [input],
144133
"Label": [label],
145134
"TP": [tp],
146135
"TN": [tn],
@@ -156,4 +145,4 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
156145
"FPOut": [fp],
157146
"FNOut": [fn]
158147
})
159-
return auc_out
148+
return auc_out, [tp, tn, fp, fn]

python/paddle/fluid/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -591,7 +591,7 @@ def update(self, preds, labels):
591591
for i in range(self._num_thresholds - 2)]
592592
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
593593

594-
# caculate TP, FN, TN, FP count
594+
# calculate TP, FN, TN, FP count
595595
for idx_thresh, thresh in enumerate(thresholds):
596596
tp, fn, tn, fp = 0, 0, 0, 0
597597
for i, lbl in enumerate(labels):

python/paddle/fluid/tests/unittests/test_auc_op.py

Lines changed: 11 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -15,13 +15,13 @@
1515
import unittest
1616
import numpy as np
1717
from op_test import OpTest
18+
from paddle.fluid import metrics
1819

1920

2021
class TestAucOp(OpTest):
2122
def setUp(self):
2223
self.op_type = "auc"
2324
pred = np.random.random((128, 2)).astype("float32")
24-
indices = np.random.randint(0, 2, (128, 2))
2525
labels = np.random.randint(0, 2, (128, 1))
2626
num_thresholds = 200
2727
tp = np.zeros((num_thresholds, )).astype("int64")
@@ -30,66 +30,26 @@ def setUp(self):
3030
fn = np.zeros((num_thresholds, )).astype("int64")
3131

3232
self.inputs = {
33-
'Out': pred,
34-
'Indices': indices,
33+
'Predict': pred,
3534
'Label': labels,
3635
'TP': tp,
3736
'TN': tn,
3837
'FP': fp,
3938
'FN': fn
4039
}
4140
self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds}
42-
# NOTE: sklearn use a different way to generate thresholds
43-
# which will cause the result differs slightly:
44-
# from sklearn.metrics import roc_curve, auc
45-
# fpr, tpr, thresholds = roc_curve(labels, pred)
46-
# auc_value = auc(fpr, tpr)
47-
# we caculate AUC again using numpy for testing
48-
kepsilon = 1e-7 # to account for floating point imprecisions
49-
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
50-
for i in range(num_thresholds - 2)]
51-
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
5241

53-
# caculate TP, FN, TN, FP count
54-
tp_list = np.ndarray((num_thresholds, ))
55-
fn_list = np.ndarray((num_thresholds, ))
56-
tn_list = np.ndarray((num_thresholds, ))
57-
fp_list = np.ndarray((num_thresholds, ))
58-
for idx_thresh, thresh in enumerate(thresholds):
59-
tp, fn, tn, fp = 0, 0, 0, 0
60-
for i, lbl in enumerate(labels):
61-
if lbl:
62-
if pred[i, 0] >= thresh:
63-
tp += 1
64-
else:
65-
fn += 1
66-
else:
67-
if pred[i, 0] >= thresh:
68-
fp += 1
69-
else:
70-
tn += 1
71-
tp_list[idx_thresh] = tp
72-
fn_list[idx_thresh] = fn
73-
tn_list[idx_thresh] = tn
74-
fp_list[idx_thresh] = fp
75-
76-
epsilon = 1e-6
77-
tpr = (tp_list.astype("float32") + epsilon) / (
78-
tp_list + fn_list + epsilon)
79-
fpr = fp_list.astype("float32") / (fp_list + tn_list + epsilon)
80-
rec = (tp_list.astype("float32") + epsilon) / (
81-
tp_list + fp_list + epsilon)
82-
83-
x = fpr[:num_thresholds - 1] - fpr[1:]
84-
y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0
85-
auc_value = np.sum(x * y)
42+
python_auc = metrics.Auc(name="auc",
43+
curve='ROC',
44+
num_thresholds=num_thresholds)
45+
python_auc.update(pred, labels)
8646

8747
self.outputs = {
88-
'AUC': auc_value,
89-
'TPOut': tp_list,
90-
'FNOut': fn_list,
91-
'TNOut': tn_list,
92-
'FPOut': fp_list
48+
'AUC': python_auc.eval(),
49+
'TPOut': python_auc.tp_list,
50+
'FNOut': python_auc.fn_list,
51+
'TNOut': python_auc.tn_list,
52+
'FPOut': python_auc.fp_list
9353
}
9454

9555
def test_check_output(self):

0 commit comments

Comments
 (0)