Skip to content

Commit 3ba75d4

Browse files
authored
Check label range in cross entropy calculation. (#10954)
1 parent 0c44efb commit 3ba75d4

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

paddle/fluid/operators/math/cross_entropy.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
4646

4747
const int64_t* label_data = labels->data<int64_t>();
4848
for (int i = 0; i < batch_size; ++i) {
49-
int index = i * class_num + label_data[i];
49+
int lbl = label_data[i];
50+
PADDLE_ENFORCE_GE(lbl, 0);
51+
PADDLE_ENFORCE_LT(lbl, class_num);
52+
int index = i * class_num + lbl;
5053
loss_data[i] = -math::TolerableValue<T>()(std::log(prob_data[index]));
5154
}
5255
}

0 commit comments

Comments
 (0)