Skip to content

Commit cc9028b

Browse files
authored
cherry-pick enforce for auc (#14687) (#14694)
* add enforce for AUC, test=release/1.2
1 parent fa6c2b5 commit cc9028b

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

paddle/fluid/operators/metrics/auc_op.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,13 @@ class AucKernel : public framework::OpKernel<T> {
7575
const auto *label_data = label->data<int64_t>();
7676

7777
for (size_t i = 0; i < batch_size; i++) {
78-
uint32_t binIdx = static_cast<uint32_t>(
79-
inference_data[i * inference_width + 1] * num_thresholds);
78+
auto predict_data = inference_data[i * inference_width + 1];
79+
PADDLE_ENFORCE_LE(predict_data, 1,
80+
"The predict data must less or equal 1.");
81+
PADDLE_ENFORCE_GE(predict_data, 0,
82+
"The predict data must gather or equal 0.");
83+
84+
uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
8085
if (label_data[i]) {
8186
(*stat_pos)[binIdx] += 1.0;
8287
} else {

0 commit comments

Comments
 (0)