We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent fa6c2b5 commit cc9028bCopy full SHA for cc9028b
paddle/fluid/operators/metrics/auc_op.h
@@ -75,8 +75,13 @@ class AucKernel : public framework::OpKernel<T> {
75
const auto *label_data = label->data<int64_t>();
76
77
for (size_t i = 0; i < batch_size; i++) {
78
- uint32_t binIdx = static_cast<uint32_t>(
79
- inference_data[i * inference_width + 1] * num_thresholds);
+ auto predict_data = inference_data[i * inference_width + 1];
+ PADDLE_ENFORCE_LE(predict_data, 1,
80
+ "The predict data must less or equal 1.");
81
+ PADDLE_ENFORCE_GE(predict_data, 0,
82
+ "The predict data must gather or equal 0.");
83
+
84
+ uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
85
if (label_data[i]) {
86
(*stat_pos)[binIdx] += 1.0;
87
} else {
0 commit comments