@@ -13,116 +13,95 @@ See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
15
#pragma once
16
+
16
17
#include < string>
17
18
#include < vector>
18
- #include " paddle/fluid/framework/eigen.h"
19
19
#include " paddle/fluid/framework/op_registry.h"
20
20
21
21
namespace paddle {
22
22
namespace operators {
23
23
24
24
using Tensor = framework::Tensor;
25
25
26
- template <typename T, int MajorType = Eigen::RowMajor,
27
- typename IndexType = Eigen::DenseIndex>
28
- using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
29
-
30
26
template <typename DeviceContext, typename T>
31
27
class AucKernel : public framework ::OpKernel<T> {
32
28
public:
33
- void Compute (const framework::ExecutionContext& ctx) const override {
34
- auto * predict = ctx.Input <Tensor>(" Predict" );
35
- auto * label = ctx.Input <Tensor>(" Label" );
36
- auto * auc = ctx.Output <Tensor>(" AUC" );
29
+ void Compute (const framework::ExecutionContext &ctx) const override {
30
+ auto *predict = ctx.Input <Tensor>(" Predict" );
31
+ auto *label = ctx.Input <Tensor>(" Label" );
32
+
33
+ std::string curve = ctx.Attr <std::string>(" curve" );
34
+ int num_thresholds = ctx.Attr <int >(" num_thresholds" );
35
+ int num_pred_buckets = num_thresholds + 1 ;
36
+
37
37
// Only use output var for now, make sure it's persistable and
38
38
// not cleaned up for each batch.
39
- auto * true_positive = ctx.Output <Tensor>(" TPOut" );
40
- auto * false_positive = ctx.Output <Tensor>(" FPOut" );
41
- auto * true_negative = ctx.Output <Tensor>(" TNOut" );
42
- auto * false_negative = ctx.Output <Tensor>(" FNOut" );
39
+ auto *auc = ctx.Output <Tensor>(" AUC" );
40
+ auto *stat_pos = ctx.Output <Tensor>(" StatPosOut" );
41
+ auto *stat_neg = ctx.Output <Tensor>(" StatNegOut" );
43
42
44
- auto * auc_data = auc->mutable_data <double >(ctx.GetPlace ());
43
+ auto *stat_pos_data = stat_pos->mutable_data <int64_t >(ctx.GetPlace ());
44
+ auto *stat_neg_data = stat_neg->mutable_data <int64_t >(ctx.GetPlace ());
45
+ calcAuc (ctx, label, predict, stat_pos_data, stat_neg_data, num_thresholds,
46
+ auc);
45
47
46
- std::string curve = ctx.Attr <std::string>(" curve" );
47
- int num_thresholds = ctx.Attr <int >(" num_thresholds" );
48
- std::vector<double > thresholds_list;
49
- thresholds_list.reserve (num_thresholds);
50
- for (int i = 1 ; i < num_thresholds - 1 ; i++) {
51
- thresholds_list[i] = static_cast <double >(i) / (num_thresholds - 1 );
52
- }
53
- const double kEpsilon = 1e-7 ;
54
- thresholds_list[0 ] = 0 .0f - kEpsilon ;
55
- thresholds_list[num_thresholds - 1 ] = 1 .0f + kEpsilon ;
48
+ auto *batch_auc = ctx.Output <Tensor>(" BatchAUC" );
49
+ std::vector<int64_t > stat_pos_batch (num_pred_buckets, 0 );
50
+ std::vector<int64_t > stat_neg_batch (num_pred_buckets, 0 );
51
+ calcAuc (ctx, label, predict, stat_pos_batch.data (), stat_neg_batch.data (),
52
+ num_thresholds, batch_auc);
53
+ }
56
54
55
+ private:
56
+ inline static double trapezoidArea (double X1, double X2, double Y1,
57
+ double Y2) {
58
+ return (X1 > X2 ? (X1 - X2) : (X2 - X1)) * (Y1 + Y2) / 2.0 ;
59
+ }
60
+
61
+ inline static void calcAuc (const framework::ExecutionContext &ctx,
62
+ const framework::Tensor *label,
63
+ const framework::Tensor *predict,
64
+ int64_t *stat_pos, int64_t *stat_neg,
65
+ int num_thresholds,
66
+ framework::Tensor *auc_tensor) {
57
67
size_t batch_size = predict->dims ()[0 ];
58
68
size_t inference_width = predict->dims ()[1 ];
69
+ const T *inference_data = predict->data <T>();
70
+ const auto *label_data = label->data <int64_t >();
71
+
72
+ auto *auc = auc_tensor->mutable_data <double >(ctx.GetPlace ());
59
73
60
- const T* inference_data = predict->data <T>();
61
- const auto * label_data = label->data <int64_t >();
62
-
63
- auto * tp_data = true_positive->mutable_data <int64_t >(ctx.GetPlace ());
64
- auto * fn_data = false_negative->mutable_data <int64_t >(ctx.GetPlace ());
65
- auto * tn_data = true_negative->mutable_data <int64_t >(ctx.GetPlace ());
66
- auto * fp_data = false_positive->mutable_data <int64_t >(ctx.GetPlace ());
67
-
68
- for (int idx_thresh = 0 ; idx_thresh < num_thresholds; idx_thresh++) {
69
- // calculate TP, FN, TN, FP for current thresh
70
- int64_t tp = 0 , fn = 0 , tn = 0 , fp = 0 ;
71
- for (size_t i = 0 ; i < batch_size; i++) {
72
- // NOTE: label_data used as bool, labels > 0 will be treated as true.
73
- if (label_data[i]) {
74
- if (inference_data[i * inference_width + 1 ] >=
75
- (thresholds_list[idx_thresh])) {
76
- tp++;
77
- } else {
78
- fn++;
79
- }
80
- } else {
81
- if (inference_data[i * inference_width + 1 ] >=
82
- (thresholds_list[idx_thresh])) {
83
- fp++;
84
- } else {
85
- tn++;
86
- }
87
- }
74
+ for (size_t i = 0 ; i < batch_size; i++) {
75
+ uint32_t binIdx = static_cast <uint32_t >(
76
+ inference_data[i * inference_width + 1 ] * num_thresholds);
77
+ if (label_data[i]) {
78
+ stat_pos[binIdx] += 1.0 ;
79
+ } else {
80
+ stat_neg[binIdx] += 1.0 ;
88
81
}
89
- // store rates
90
- tp_data[idx_thresh] += tp;
91
- fn_data[idx_thresh] += fn;
92
- tn_data[idx_thresh] += tn;
93
- fp_data[idx_thresh] += fp;
94
82
}
95
- // epsilon to avoid divide by zero.
96
- double epsilon = 1e-6 ;
97
- // Riemann sum to caculate auc.
98
- Tensor tp_rate, fp_rate, rec_rate;
99
- tp_rate.Resize ({num_thresholds});
100
- fp_rate.Resize ({num_thresholds});
101
- rec_rate.Resize ({num_thresholds});
102
- auto * tp_rate_data = tp_rate.mutable_data <double >(ctx.GetPlace ());
103
- auto * fp_rate_data = fp_rate.mutable_data <double >(ctx.GetPlace ());
104
- auto * rec_rate_data = rec_rate.mutable_data <double >(ctx.GetPlace ());
105
- for (int i = 0 ; i < num_thresholds; i++) {
106
- tp_rate_data[i] = (static_cast <double >(tp_data[i]) + epsilon) /
107
- (tp_data[i] + fn_data[i] + epsilon);
108
- fp_rate_data[i] =
109
- static_cast <double >(fp_data[i]) / (fp_data[i] + tn_data[i] + epsilon);
110
- rec_rate_data[i] = (static_cast <double >(tp_data[i]) + epsilon) /
111
- (tp_data[i] + fp_data[i] + epsilon);
83
+
84
+ *auc = 0 .0f ;
85
+
86
+ double totPos = 0.0 ;
87
+ double totNeg = 0.0 ;
88
+ double totPosPrev = 0.0 ;
89
+ double totNegPrev = 0.0 ;
90
+
91
+ int idx = num_thresholds;
92
+
93
+ while (idx >= 0 ) {
94
+ totPosPrev = totPos;
95
+ totNegPrev = totNeg;
96
+ totPos += stat_pos[idx];
97
+ totNeg += stat_neg[idx];
98
+ *auc += trapezoidArea (totNeg, totNegPrev, totPos, totPosPrev);
99
+
100
+ --idx;
112
101
}
113
- *auc_data = 0 .0f ;
114
- if (curve == " ROC" ) {
115
- for (int i = 0 ; i < num_thresholds - 1 ; i++) {
116
- auto dx = fp_rate_data[i] - fp_rate_data[i + 1 ];
117
- auto y = (tp_rate_data[i] + tp_rate_data[i + 1 ]) / 2 .0f ;
118
- *auc_data = *auc_data + dx * y;
119
- }
120
- } else if (curve == " PR" ) {
121
- for (int i = 1 ; i < num_thresholds; i++) {
122
- auto dx = tp_rate_data[i] - tp_rate_data[i - 1 ];
123
- auto y = (rec_rate_data[i] + rec_rate_data[i - 1 ]) / 2 .0f ;
124
- *auc_data = *auc_data + dx * y;
125
- }
102
+
103
+ if (totPos > 0.0 && totNeg > 0.0 ) {
104
+ *auc = *auc / totPos / totNeg;
126
105
}
127
106
}
128
107
};
0 commit comments