Skip to content

Commit faf8ad2

Browse files
author
Bai Yifan
authored
Add ignore_index in cross_entropy op (#13217)
* add ignore index
* update api.spec
* enhance softmax_with_cross_entropy
1 parent 94b66bd commit faf8ad2

13 files changed

+148
-29
lines changed

paddle/fluid/API.spec

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_att
100100
paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr'], varargs=None, keywords=None, defaults=(None,))
101101
paddle.fluid.layers.crf_decoding ArgSpec(args=['input', 'param_attr', 'label'], varargs=None, keywords=None, defaults=(None,))
102102
paddle.fluid.layers.cos_sim ArgSpec(args=['X', 'Y'], varargs=None, keywords=None, defaults=None)
103-
paddle.fluid.layers.cross_entropy ArgSpec(args=['input', 'label', 'soft_label'], varargs=None, keywords=None, defaults=(False,))
103+
paddle.fluid.layers.cross_entropy ArgSpec(args=['input', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100))
104104
paddle.fluid.layers.square_error_cost ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None)
105105
paddle.fluid.layers.chunk_eval ArgSpec(args=['input', 'label', 'chunk_scheme', 'num_chunk_types', 'excluded_chunk_types'], varargs=None, keywords=None, defaults=(None,))
106106
paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None))
@@ -142,7 +142,7 @@ paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 's
142142
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
143143
paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
144144
paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None))
145-
paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label'], varargs=None, keywords=None, defaults=(False,))
145+
paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100))
146146
paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
147147
paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
148148
paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1))

paddle/fluid/operators/cross_entropy_op.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,11 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
138138
"(bool, default false), a flag indicating whether to "
139139
"interpretate the given labels as soft labels.")
140140
.SetDefault(false);
141+
AddAttr<int>("ignore_index",
142+
"(int, default -100), Specifies a target value that is"
143+
"ignored and does not contribute to the input gradient."
144+
"Only valid if soft_label is set to False")
145+
.SetDefault(-100);
141146
AddComment(R"DOC(
142147
CrossEntropy Operator.
143148

paddle/fluid/operators/cross_entropy_op.h

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
4040

4141
math::CrossEntropyFunctor<DeviceContext, T>()(
4242
ctx.template device_context<DeviceContext>(), &y_2d, &x_2d, &labels_2d,
43-
ctx.Attr<bool>("soft_label"));
43+
ctx.Attr<bool>("soft_label"), ctx.Attr<int>("ignore_index"));
4444
}
4545
};
4646

@@ -74,16 +74,22 @@ class XeGradFunctor {
7474
const T* dy, // NOLINT
7575
const T* x, // NOLINT
7676
const int64_t* label, // NOLINT
77-
size_t num_classes)
78-
: dx_(dx), dy_(dy), x_(x), label_(label), num_classes_(num_classes) {}
77+
size_t num_classes, size_t ignore_index)
78+
: dx_(dx),
79+
dy_(dy),
80+
x_(x),
81+
label_(label),
82+
num_classes_(num_classes),
83+
ignore_index_(ignore_index) {}
7984

8085
HOSTDEVICE void operator()(size_t sample_id) {
8186
auto x_is_true_offset = sample_id * num_classes_ + label_[sample_id];
8287
for (size_t x_offset = sample_id * num_classes_;
8388
x_offset < (sample_id + 1) * num_classes_; ++x_offset) {
84-
dx_[x_offset] = x_offset != x_is_true_offset
85-
? static_cast<T>(0)
86-
: -dy_[sample_id] / x_[x_offset];
89+
dx_[x_offset] =
90+
(x_offset != x_is_true_offset || label_[sample_id] == ignore_index_)
91+
? static_cast<T>(0)
92+
: -dy_[sample_id] / x_[x_offset];
8793
}
8894
}
8995

@@ -93,6 +99,7 @@ class XeGradFunctor {
9399
const T* x_;
94100
const int64_t* label_;
95101
size_t num_classes_;
102+
size_t ignore_index_;
96103
};
97104

98105
template <typename DeviceContext, typename T>
@@ -109,6 +116,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
109116
// unnecessary to convert tensors to 2-D views.
110117
int rank = x->dims().size();
111118
int64_t class_num = x->dims()[rank - 1];
119+
int64_t ignore_index = ctx.Attr<int>("ignore_index");
112120
if (ctx.Attr<bool>("soft_label")) {
113121
XeSoftlabelGradFunctor<T> functor(dx_data, dy->data<T>(), x->data<T>(),
114122
label->data<T>(),
@@ -118,9 +126,9 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
118126
static_cast<size_t>(dx->numel()));
119127
for_range(functor);
120128
} else {
121-
XeGradFunctor<T> functor(dx_data, dy->data<T>(), x->data<T>(),
122-
label->data<int64_t>(),
123-
static_cast<size_t>(class_num));
129+
XeGradFunctor<T> functor(
130+
dx_data, dy->data<T>(), x->data<T>(), label->data<int64_t>(),
131+
static_cast<size_t>(class_num), static_cast<size_t>(ignore_index));
124132
platform::ForRange<DeviceContext> for_range(
125133
ctx.template device_context<DeviceContext>(),
126134
static_cast<size_t>(dy->numel()));

paddle/fluid/operators/math/cross_entropy.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
2828
public:
2929
void operator()(const platform::CPUDeviceContext& ctx, framework::Tensor* out,
3030
const framework::Tensor* prob,
31-
const framework::Tensor* labels, const bool softLabel) {
31+
const framework::Tensor* labels, const bool softLabel,
32+
const int ignore_index) {
3233
const int batch_size = prob->dims()[0];
3334
if (softLabel) {
3435
auto in = EigenMatrix<T>::From(*prob);
@@ -49,8 +50,12 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
4950
int lbl = label_data[i];
5051
PADDLE_ENFORCE_GE(lbl, 0);
5152
PADDLE_ENFORCE_LT(lbl, class_num);
53+
PADDLE_ENFORCE((lbl >= 0 && lbl < class_num) || lbl == ignore_index);
5254
int index = i * class_num + lbl;
53-
loss_data[i] = -math::TolerableValue<T>()(std::log(prob_data[index]));
55+
loss_data[i] =
56+
lbl == ignore_index
57+
? 0
58+
: -math::TolerableValue<T>()(std::log(prob_data[index]));
5459
}
5560
}
5661
}

paddle/fluid/operators/math/cross_entropy.cu

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,14 @@ namespace math {
2323
namespace {
2424
template <typename T>
2525
__global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
26-
const int N, const int D) {
26+
const int N, const int D,
27+
const int ignore_index) {
2728
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
2829
i += blockDim.x * gridDim.x) {
29-
PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
30-
Y[i] = -math::TolerableValue<T>()(log(X[i * D + label[i]]));
30+
PADDLE_ASSERT(label[i] >= 0 && label[i] < D || label[i] == ignore_index);
31+
Y[i] = ignore_index == label[i]
32+
? 0
33+
: -math::TolerableValue<T>()(log(X[i * D + label[i]]));
3134
}
3235
}
3336

@@ -57,7 +60,8 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
5760
public:
5861
void operator()(const platform::CUDADeviceContext& ctx,
5962
framework::Tensor* out, const framework::Tensor* prob,
60-
const framework::Tensor* labels, bool softLabel) {
63+
const framework::Tensor* labels, bool softLabel,
64+
const int ignore_index) {
6165
const T* prob_data = prob->data<T>();
6266
T* loss_data = out->mutable_data<T>(ctx.GetPlace());
6367

@@ -77,7 +81,8 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
7781
int block = 512;
7882
int grid = (batch_size + block - 1) / block;
7983
CrossEntropyKernel<T><<<grid, block, 0, ctx.stream()>>>(
80-
loss_data, prob_data, label_data, batch_size, class_num);
84+
loss_data, prob_data, label_data, batch_size, class_num,
85+
ignore_index);
8186
}
8287
}
8388
};

paddle/fluid/operators/math/cross_entropy.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ class CrossEntropyFunctor {
3838
public:
3939
void operator()(const DeviceContext& context, framework::Tensor* out,
4040
const framework::Tensor* prob,
41-
const framework::Tensor* labels, const bool softLabel);
41+
const framework::Tensor* labels, const bool softLabel,
42+
const int ignore_index);
4243
};
4344
} // namespace math
4445
} // namespace operators

paddle/fluid/operators/softmax_with_cross_entropy_op.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ class SoftmaxWithCrossEntropyOpMaker
4444
"(bool, default: false), A flag to indicate whether to interpretate "
4545
"the given labels as soft labels.")
4646
.SetDefault(false);
47+
AddAttr<int>(
48+
"ignore_index",
49+
"(int, default -100), Specifies a target value that is ignored and"
50+
"does not contribute to the input gradient. Only valid if soft_label"
51+
"is set to False")
52+
.SetDefault(-100);
4753
AddComment(R"DOC(
4854
Softmax With Cross Entropy Operator.
4955

paddle/fluid/operators/softmax_with_cross_entropy_op.cu

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ using Tensor = framework::Tensor;
2626
namespace {
2727
template <typename T>
2828
__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
29-
const int batch_size, const int class_num) {
29+
const int batch_size, const int class_num,
30+
const int ignore_index) {
3031
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size;
3132
i += blockDim.x * gridDim.x) {
3233
int idx = i * class_num + labels[i];
@@ -260,6 +261,7 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
260261
auto* loss_data = loss->mutable_data<T>(context.GetPlace());
261262

262263
auto soft_label = context.Attr<bool>("soft_label");
264+
auto ignore_index = context.Attr<int>("ignore_index");
263265
if (soft_label) {
264266
int batch_size = logits->dims()[0];
265267
int feature_size = logits->dims()[1];
@@ -272,7 +274,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
272274
math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(), logits,
273275
softmax);
274276
math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
275-
context.cuda_device_context(), loss, softmax, labels, false);
277+
context.cuda_device_context(), loss, softmax, labels, false,
278+
ignore_index);
276279
}
277280
}
278281
};
@@ -295,7 +298,7 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
295298
const int class_num = logit_grad->dims()[1];
296299
int block = 512;
297300
auto stream = context.cuda_device_context().stream();
298-
301+
auto ignore_index = context.Attr<int>("ignore_index");
299302
if (context.Attr<bool>("soft_label")) {
300303
int grid = (batch_size * class_num + block - 1) / block;
301304
const T* label_data = labels->data<T>();
@@ -305,7 +308,7 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
305308
int grid = (batch_size + block - 1) / block;
306309
const int64_t* label_data = labels->data<int64_t>();
307310
CrossEntropyGrad<T><<<grid, block, 0, stream>>>(
308-
logit_grad_data, label_data, batch_size, class_num);
311+
logit_grad_data, label_data, batch_size, class_num, ignore_index);
309312
int num = batch_size * class_num;
310313
grid = (num + block - 1) / block;
311314
Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num,

paddle/fluid/operators/softmax_with_cross_entropy_op.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
4545
math::SoftmaxFunctor<platform::CPUDeviceContext, T>()(dev_ctx, logits,
4646
softmax);
4747
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
48-
dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"));
48+
dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"),
49+
context.Attr<int>("ignore_index"));
4950
}
5051
};
5152

python/paddle/fluid/layers/nn.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -968,7 +968,7 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
968968
return out
969969

970970

971-
def cross_entropy(input, label, soft_label=False):
971+
def cross_entropy(input, label, soft_label=False, ignore_index=-100):
972972
"""
973973
**Cross Entropy Layer**
974974
@@ -1012,7 +1012,10 @@ def cross_entropy(input, label, soft_label=False):
10121012
tensor<float/double> with shape [N x D].
10131013
soft_label (bool): a flag indicating whether to
10141014
interpretate the given labels as soft
1015-
labels, default `False`.
1015+
labels. Default: `False`.
1016+
ignore_index (int): Specifies a target value that is ignored and does
1017+
not contribute to the input gradient. Only valid
1018+
if soft_label is set to False. Default: -100
10161019
10171020
Returns:
10181021
A 2-D tensor with shape [N x 1], the cross entropy loss.
@@ -1037,7 +1040,8 @@ def cross_entropy(input, label, soft_label=False):
10371040
inputs={'X': [input],
10381041
'Label': [label]},
10391042
outputs={'Y': [out]},
1040-
attrs={"soft_label": soft_label})
1043+
attrs={"soft_label": soft_label,
1044+
"ignore_index": ignore_index})
10411045
return out
10421046

10431047

@@ -4242,7 +4246,10 @@ def multiplex(inputs, index):
42424246
return out
42434247

42444248

4245-
def softmax_with_cross_entropy(logits, label, soft_label=False):
4249+
def softmax_with_cross_entropy(logits,
4250+
label,
4251+
soft_label=False,
4252+
ignore_index=-100):
42464253
"""
42474254
**Softmax With Cross Entropy Operator.**
42484255
@@ -4284,6 +4291,10 @@ def softmax_with_cross_entropy(logits, label, soft_label=False):
42844291
soft_label is set to true, Label is a Tensor<float/double> with
42854292
soft_label (bool): A flag to indicate whether to interpretate the given
42864293
labels as soft labels. By default, `soft_label` is set to False.
4294+
ignore_index (int): Specifies a target value that is ignored and does
4295+
not contribute to the input gradient. Only valid
4296+
if soft_label is set to False. Default: -100
4297+
42874298
Returns:
42884299
Variable: The cross entropy loss is a 2-D tensor with shape [N x 1].
42894300
@@ -4305,7 +4316,8 @@ def softmax_with_cross_entropy(logits, label, soft_label=False):
43054316
'Label': label},
43064317
outputs={'Softmax': softmax,
43074318
'Loss': loss},
4308-
attrs={'soft_label': soft_label})
4319+
attrs={'soft_label': soft_label,
4320+
'ignore_index': ignore_index})
43094321
return loss
43104322

43114323

0 commit comments

Comments
 (0)