Skip to content

Commit 536d9a3

Browse files
authored
[cherry-pick] fix error message & label check in softmax_with_cross_entropy (#31123)
* fix error message & label check in softmax_with_cross_entropy
* fix error message & label check in softmax_with_cross_entropy
* fix print comment
* fix ignore_index check in softmax_with_cross_entropy
1 parent 84a5ed9 commit 536d9a3

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

paddle/fluid/operators/softmax_with_cross_entropy_op.cu

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -253,12 +253,13 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
253253
public:
254254
HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss,
255255
T* log_softmax, int64_t d,
256-
int axis_dim)
256+
int axis_dim, int ignore_idx)
257257
: labels_(labels),
258258
loss_(loss),
259259
log_softmax_(log_softmax),
260260
d_(d),
261-
axis_dim_(axis_dim) {}
261+
axis_dim_(axis_dim),
262+
ignore_idx_(ignore_idx) {}
262263

263264
__device__ void operator()(int64_t idx) const {
264265
// logits view as [n, axis_dim, remain], where d = axis_dim * remain
@@ -268,6 +269,11 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
268269
int64_t idx_remain = idx % remain;
269270
// labels, loss view as [n, remain]
270271
int64_t idx_lbl = idx_n * remain + idx_remain;
272+
PADDLE_ENFORCE(labels_[idx_lbl] >= 0 && labels_[idx_lbl] < d_ ||
273+
labels_[idx_lbl] == ignore_idx_,
274+
"The value of label[%ld] expected >= 0 and < %ld, or == %d,"
275+
"but got %ld. Please check input value.",
276+
idx_lbl, d_, ignore_idx_, labels_[idx_lbl]);
271277
// It also would ignore labels not in range(class_num).
272278
if (idx_axis != labels_[idx_lbl]) {
273279
log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
@@ -284,6 +290,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
284290
T* log_softmax_;
285291
int64_t d_;
286292
int axis_dim_;
293+
int ignore_idx_;
287294
};
288295

289296
template <typename T>
@@ -351,7 +358,7 @@ static void HardLabelSoftmaxWithCrossEntropy(
351358
labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
352359
} else { \
353360
for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>( \
354-
labels_data, loss_data, softmax_data, d, axis_dim)); \
361+
labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
355362
} \
356363
} break
357364

0 commit comments

Comments
 (0)