@@ -253,12 +253,13 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
253
253
public:
254
254
HardLabelSoftmaxWithCrossEntropyFunctor (const int64_t * labels, T* loss,
255
255
T* log_softmax, int64_t d,
256
- int axis_dim)
256
+ int axis_dim, int ignore_idx )
257
257
: labels_(labels),
258
258
loss_ (loss),
259
259
log_softmax_(log_softmax),
260
260
d_(d),
261
- axis_dim_(axis_dim) {}
261
+ axis_dim_(axis_dim),
262
+ ignore_idx_(ignore_idx) {}
262
263
263
264
__device__ void operator ()(int64_t idx) const {
264
265
// logits view as [n, axis_dim, remain], where d = axis_dim * remain
@@ -268,6 +269,11 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
268
269
int64_t idx_remain = idx % remain;
269
270
// labels, loss view as [n, remain]
270
271
int64_t idx_lbl = idx_n * remain + idx_remain;
272
+ PADDLE_ENFORCE (labels_[idx_lbl] >= 0 && labels_[idx_lbl] < d_ ||
273
+ labels_[idx_lbl] == ignore_idx_,
274
+ " The value of label[%ld] expected >= 0 and < %ld, or == %d,"
275
+ " but got %ld. Please check input value." ,
276
+ idx_lbl, d_, ignore_idx_, labels_[idx_lbl]);
271
277
// It also would ignore labels not in range(class_num).
272
278
if (idx_axis != labels_[idx_lbl]) {
273
279
log_softmax_[idx] = exp_on_device (log_softmax_[idx]);
@@ -284,6 +290,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
284
290
T* log_softmax_;
285
291
int64_t d_;
286
292
int axis_dim_;
293
+ int ignore_idx_;
287
294
};
288
295
289
296
template <typename T>
@@ -351,7 +358,7 @@ static void HardLabelSoftmaxWithCrossEntropy(
351
358
labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
352
359
} else { \
353
360
for_range (HardLabelSoftmaxWithCrossEntropyFunctor<T>( \
354
- labels_data, loss_data, softmax_data, d, axis_dim)); \
361
+ labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx )); \
355
362
} \
356
363
} break
357
364
0 commit comments