Skip to content

Commit d1e2efa

Browse files
authored
reimplement auc in fluid (#13167)
* reimplement auc in python * reimplement auc in fluid * add auc unittest * replace new auc in layers * add batch Auc in Fluid * name formatted
1 parent f94fdea commit d1e2efa

File tree

6 files changed

+134
-181
lines changed

6 files changed

+134
-181
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kw
312312
paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
313313
paddle.fluid.layers.polygon_box_transform ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
314314
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
315-
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk'], varargs=None, keywords=None, defaults=('ROC', 200, 1))
315+
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk'], varargs=None, keywords=None, defaults=('ROC', 4095, 1))
316316
paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
317317
paddle.fluid.layers.natural_exp_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
318318
paddle.fluid.layers.inverse_time_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))

paddle/fluid/operators/auc_op.cc

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/auc_op.h"
16-
#include <string>
1716

1817
namespace paddle {
1918
namespace operators {
@@ -36,15 +35,12 @@ class AucOp : public framework::OperatorWithKernel {
3635
PADDLE_ENFORCE_EQ(predict_height, label_height,
3736
"Out and Label should have same height.");
3837

39-
int num_thres = ctx->Attrs().Get<int>("num_thresholds");
38+
int num_pred_buckets = ctx->Attrs().Get<int>("num_thresholds") + 1;
4039

4140
ctx->SetOutputDim("AUC", {1});
42-
ctx->SetOutputDim("TPOut", {num_thres});
43-
ctx->SetOutputDim("TNOut", {num_thres});
44-
ctx->SetOutputDim("FPOut", {num_thres});
45-
ctx->SetOutputDim("FNOut", {num_thres});
46-
47-
ctx->ShareLoD("Predict", /*->*/ "AUC");
41+
ctx->SetOutputDim("BatchAUC", {1});
42+
ctx->SetOutputDim("StatPosOut", {num_pred_buckets});
43+
ctx->SetOutputDim("StatNegOut", {num_pred_buckets});
4844
}
4945

5046
protected:
@@ -66,25 +62,24 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
6662
AddInput("Label",
6763
"A 2D int tensor indicating the label of the training data. "
6864
"shape: [batch_size, 1]");
69-
AddInput("TP", "True-Positive value.");
70-
AddInput("FP", "False-Positive value.");
71-
AddInput("TN", "True-Negative value.");
72-
AddInput("FN", "False-Negative value.");
7365
// TODO(typhoonzero): support weight input
66+
AddInput("StatPos", "Statistic value when label = 1");
67+
AddInput("StatNeg", "Statistic value when label = 0");
68+
7469
AddOutput("AUC",
7570
"A scalar representing the "
7671
"current area-under-the-curve.");
77-
AddOutput("TPOut", "True-Positive value.");
78-
AddOutput("FPOut", "False-Positive value.");
79-
AddOutput("TNOut", "True-Negative value.");
80-
AddOutput("FNOut", "False-Negative value.");
72+
AddOutput("BatchAUC", "The AUC for current batch");
73+
AddOutput("StatPosOut", "Statistic value when label = 1");
74+
AddOutput("StatNegOut", "Statistic value when label = 0");
8175

8276
AddAttr<std::string>("curve", "Curve type, can be 'ROC' or 'PR'.")
8377
.SetDefault("ROC");
78+
8479
AddAttr<int>("num_thresholds",
8580
"The number of thresholds to use when discretizing the"
8681
" roc curve.")
87-
.SetDefault(200);
82+
.SetDefault((2 << 12) - 1);
8883

8984
AddComment(R"DOC(
9085
Area Under The Curve (AUC) Operator.

paddle/fluid/operators/auc_op.h

Lines changed: 66 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -13,116 +13,95 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#pragma once
16+
1617
#include <string>
1718
#include <vector>
18-
#include "paddle/fluid/framework/eigen.h"
1919
#include "paddle/fluid/framework/op_registry.h"
2020

2121
namespace paddle {
2222
namespace operators {
2323

2424
using Tensor = framework::Tensor;
2525

26-
template <typename T, int MajorType = Eigen::RowMajor,
27-
typename IndexType = Eigen::DenseIndex>
28-
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
29-
3026
template <typename DeviceContext, typename T>
3127
class AucKernel : public framework::OpKernel<T> {
3228
public:
33-
void Compute(const framework::ExecutionContext& ctx) const override {
34-
auto* predict = ctx.Input<Tensor>("Predict");
35-
auto* label = ctx.Input<Tensor>("Label");
36-
auto* auc = ctx.Output<Tensor>("AUC");
29+
void Compute(const framework::ExecutionContext &ctx) const override {
30+
auto *predict = ctx.Input<Tensor>("Predict");
31+
auto *label = ctx.Input<Tensor>("Label");
32+
33+
std::string curve = ctx.Attr<std::string>("curve");
34+
int num_thresholds = ctx.Attr<int>("num_thresholds");
35+
int num_pred_buckets = num_thresholds + 1;
36+
3737
// Only use output var for now, make sure it's persistable and
3838
// not cleaned up for each batch.
39-
auto* true_positive = ctx.Output<Tensor>("TPOut");
40-
auto* false_positive = ctx.Output<Tensor>("FPOut");
41-
auto* true_negative = ctx.Output<Tensor>("TNOut");
42-
auto* false_negative = ctx.Output<Tensor>("FNOut");
39+
auto *auc = ctx.Output<Tensor>("AUC");
40+
auto *stat_pos = ctx.Output<Tensor>("StatPosOut");
41+
auto *stat_neg = ctx.Output<Tensor>("StatNegOut");
4342

44-
auto* auc_data = auc->mutable_data<double>(ctx.GetPlace());
43+
auto *stat_pos_data = stat_pos->mutable_data<int64_t>(ctx.GetPlace());
44+
auto *stat_neg_data = stat_neg->mutable_data<int64_t>(ctx.GetPlace());
45+
calcAuc(ctx, label, predict, stat_pos_data, stat_neg_data, num_thresholds,
46+
auc);
4547

46-
std::string curve = ctx.Attr<std::string>("curve");
47-
int num_thresholds = ctx.Attr<int>("num_thresholds");
48-
std::vector<double> thresholds_list;
49-
thresholds_list.reserve(num_thresholds);
50-
for (int i = 1; i < num_thresholds - 1; i++) {
51-
thresholds_list[i] = static_cast<double>(i) / (num_thresholds - 1);
52-
}
53-
const double kEpsilon = 1e-7;
54-
thresholds_list[0] = 0.0f - kEpsilon;
55-
thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon;
48+
auto *batch_auc = ctx.Output<Tensor>("BatchAUC");
49+
std::vector<int64_t> stat_pos_batch(num_pred_buckets, 0);
50+
std::vector<int64_t> stat_neg_batch(num_pred_buckets, 0);
51+
calcAuc(ctx, label, predict, stat_pos_batch.data(), stat_neg_batch.data(),
52+
num_thresholds, batch_auc);
53+
}
5654

55+
private:
56+
inline static double trapezoidArea(double X1, double X2, double Y1,
57+
double Y2) {
58+
return (X1 > X2 ? (X1 - X2) : (X2 - X1)) * (Y1 + Y2) / 2.0;
59+
}
60+
61+
inline static void calcAuc(const framework::ExecutionContext &ctx,
62+
const framework::Tensor *label,
63+
const framework::Tensor *predict,
64+
int64_t *stat_pos, int64_t *stat_neg,
65+
int num_thresholds,
66+
framework::Tensor *auc_tensor) {
5767
size_t batch_size = predict->dims()[0];
5868
size_t inference_width = predict->dims()[1];
69+
const T *inference_data = predict->data<T>();
70+
const auto *label_data = label->data<int64_t>();
71+
72+
auto *auc = auc_tensor->mutable_data<double>(ctx.GetPlace());
5973

60-
const T* inference_data = predict->data<T>();
61-
const auto* label_data = label->data<int64_t>();
62-
63-
auto* tp_data = true_positive->mutable_data<int64_t>(ctx.GetPlace());
64-
auto* fn_data = false_negative->mutable_data<int64_t>(ctx.GetPlace());
65-
auto* tn_data = true_negative->mutable_data<int64_t>(ctx.GetPlace());
66-
auto* fp_data = false_positive->mutable_data<int64_t>(ctx.GetPlace());
67-
68-
for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) {
69-
// calculate TP, FN, TN, FP for current thresh
70-
int64_t tp = 0, fn = 0, tn = 0, fp = 0;
71-
for (size_t i = 0; i < batch_size; i++) {
72-
// NOTE: label_data used as bool, labels > 0 will be treated as true.
73-
if (label_data[i]) {
74-
if (inference_data[i * inference_width + 1] >=
75-
(thresholds_list[idx_thresh])) {
76-
tp++;
77-
} else {
78-
fn++;
79-
}
80-
} else {
81-
if (inference_data[i * inference_width + 1] >=
82-
(thresholds_list[idx_thresh])) {
83-
fp++;
84-
} else {
85-
tn++;
86-
}
87-
}
74+
for (size_t i = 0; i < batch_size; i++) {
75+
uint32_t binIdx = static_cast<uint32_t>(
76+
inference_data[i * inference_width + 1] * num_thresholds);
77+
if (label_data[i]) {
78+
stat_pos[binIdx] += 1.0;
79+
} else {
80+
stat_neg[binIdx] += 1.0;
8881
}
89-
// store rates
90-
tp_data[idx_thresh] += tp;
91-
fn_data[idx_thresh] += fn;
92-
tn_data[idx_thresh] += tn;
93-
fp_data[idx_thresh] += fp;
9482
}
95-
// epsilon to avoid divide by zero.
96-
double epsilon = 1e-6;
97-
// Riemann sum to caculate auc.
98-
Tensor tp_rate, fp_rate, rec_rate;
99-
tp_rate.Resize({num_thresholds});
100-
fp_rate.Resize({num_thresholds});
101-
rec_rate.Resize({num_thresholds});
102-
auto* tp_rate_data = tp_rate.mutable_data<double>(ctx.GetPlace());
103-
auto* fp_rate_data = fp_rate.mutable_data<double>(ctx.GetPlace());
104-
auto* rec_rate_data = rec_rate.mutable_data<double>(ctx.GetPlace());
105-
for (int i = 0; i < num_thresholds; i++) {
106-
tp_rate_data[i] = (static_cast<double>(tp_data[i]) + epsilon) /
107-
(tp_data[i] + fn_data[i] + epsilon);
108-
fp_rate_data[i] =
109-
static_cast<double>(fp_data[i]) / (fp_data[i] + tn_data[i] + epsilon);
110-
rec_rate_data[i] = (static_cast<double>(tp_data[i]) + epsilon) /
111-
(tp_data[i] + fp_data[i] + epsilon);
83+
84+
*auc = 0.0f;
85+
86+
double totPos = 0.0;
87+
double totNeg = 0.0;
88+
double totPosPrev = 0.0;
89+
double totNegPrev = 0.0;
90+
91+
int idx = num_thresholds;
92+
93+
while (idx >= 0) {
94+
totPosPrev = totPos;
95+
totNegPrev = totNeg;
96+
totPos += stat_pos[idx];
97+
totNeg += stat_neg[idx];
98+
*auc += trapezoidArea(totNeg, totNegPrev, totPos, totPosPrev);
99+
100+
--idx;
112101
}
113-
*auc_data = 0.0f;
114-
if (curve == "ROC") {
115-
for (int i = 0; i < num_thresholds - 1; i++) {
116-
auto dx = fp_rate_data[i] - fp_rate_data[i + 1];
117-
auto y = (tp_rate_data[i] + tp_rate_data[i + 1]) / 2.0f;
118-
*auc_data = *auc_data + dx * y;
119-
}
120-
} else if (curve == "PR") {
121-
for (int i = 1; i < num_thresholds; i++) {
122-
auto dx = tp_rate_data[i] - tp_rate_data[i - 1];
123-
auto y = (rec_rate_data[i] + rec_rate_data[i - 1]) / 2.0f;
124-
*auc_data = *auc_data + dx * y;
125-
}
102+
103+
if (totPos > 0.0 && totNeg > 0.0) {
104+
*auc = *auc / totPos / totNeg;
126105
}
127106
}
128107
};

python/paddle/fluid/layers/metric_op.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def accuracy(input, label, k=1, correct=None, total=None):
7878
return acc_out
7979

8080

81-
def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
81+
def auc(input, label, curve='ROC', num_thresholds=2**12 - 1, topk=1):
8282
"""
8383
**Area Under the Curve (AUC) Layer**
8484
@@ -118,16 +118,14 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
118118
"""
119119
helper = LayerHelper("auc", **locals())
120120
auc_out = helper.create_tmp_variable(dtype="float64")
121+
batch_auc_out = helper.create_tmp_variable(dtype="float64")
121122
# make stat_pos and stat_neg persistable, so that they can accumulate across all batches.
122-
tp = helper.create_global_variable(
123-
persistable=True, dtype='int64', shape=[num_thresholds])
124-
tn = helper.create_global_variable(
125-
persistable=True, dtype='int64', shape=[num_thresholds])
126-
fp = helper.create_global_variable(
127-
persistable=True, dtype='int64', shape=[num_thresholds])
128-
fn = helper.create_global_variable(
129-
persistable=True, dtype='int64', shape=[num_thresholds])
130-
for var in [tp, tn, fp, fn]:
123+
stat_pos = helper.create_global_variable(
124+
persistable=True, dtype='int64', shape=[num_thresholds + 1])
125+
stat_neg = helper.create_global_variable(
126+
persistable=True, dtype='int64', shape=[num_thresholds + 1])
127+
128+
for var in [stat_pos, stat_neg]:
131129
helper.set_variable_initializer(
132130
var, Constant(
133131
value=0.0, force_cpu=True))
@@ -137,18 +135,15 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
137135
inputs={
138136
"Predict": [input],
139137
"Label": [label],
140-
"TP": [tp],
141-
"TN": [tn],
142-
"FP": [fp],
143-
"FN": [fn]
138+
"StatPos": [stat_pos],
139+
"StatNeg": [stat_neg]
144140
},
145141
attrs={"curve": curve,
146142
"num_thresholds": num_thresholds},
147143
outputs={
148144
"AUC": [auc_out],
149-
"TPOut": [tp],
150-
"TNOut": [tn],
151-
"FPOut": [fp],
152-
"FNOut": [fn]
145+
"BatchAUC": [batch_auc_out],
146+
"StatPosOut": [stat_pos],
147+
"StatNegOut": [stat_neg]
153148
})
154-
return auc_out, [tp, tn, fp, fn]
149+
return auc_out, batch_auc_out, [stat_pos, stat_neg]

0 commit comments

Comments
 (0)