
Commit 9effbbb

Merge pull request #16798 from phlrain/softmax_cross_support_high_rank
softmax cross entropy support high rank
1 parent 1237dfa commit 9effbbb

4 files changed (+238, -36 lines)

paddle/fluid/operators/softmax_with_cross_entropy_op.cc

Lines changed: 46 additions & 17 deletions
@@ -106,24 +106,36 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 
     auto logits_dims = ctx->GetInputDim("Logits");
     auto labels_dims = ctx->GetInputDim("Label");
+
+    int rank = logits_dims.size();
     PADDLE_ENFORCE_EQ(
-        logits_dims.size(), 2UL,
-        "The input of softmax_with_cross_entropy should be a 2-D tensor.");
-    PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
-                      "The labels should be a 2-D tensor.");
+        rank, labels_dims.size(),
+        "Input(logits) and Input(Label) shall have the same rank.");
+    bool check = ctx->IsRuntime() || (framework::product(logits_dims) > 0 &&
+                                      framework::product(labels_dims) > 0);
+    if (check) {
+      PADDLE_ENFORCE_EQ(framework::slice_ddim(logits_dims, 0, rank - 1),
+                        framework::slice_ddim(labels_dims, 0, rank - 1),
+                        "Input(X) and Input(Label) shall have the same shape "
+                        "except the last dimension.");
+    }
 
     if (ctx->Attrs().Get<bool>("soft_label")) {
-      PADDLE_ENFORCE_EQ(logits_dims[1], labels_dims[1],
-                        "If Attr(soft_label) == true, the 2nd dimension of "
-                        "Input(X) and Input(Label) should be equal.");
+      if (check) {
+        PADDLE_ENFORCE_EQ(logits_dims[rank - 1], labels_dims[rank - 1],
+                          "If Attr(soft_label) == true, the last dimension of "
+                          "Input(X) and Input(Label) should be equal.");
+      }
     } else {
-      PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
-                        "If Attr(soft_label) == false, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
+                        "If Attr(softLabel) == false, the last dimension of "
                         "Input(Label) should be 1.");
     }
 
     ctx->SetOutputDim("Softmax", logits_dims);
-    ctx->SetOutputDim("Loss", {logits_dims[0], 1});
+    auto loss_dims = logits_dims;
+    loss_dims[rank - 1] = 1;
+    ctx->SetOutputDim("Loss", loss_dims);
 
     ctx->ShareLoD("Logits", /*->*/ "Softmax");
     ctx->ShareLoD("Logits", /*->*/ "Loss");
@@ -152,16 +164,33 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 
     auto softmax_dims = ctx->GetInputDim("Softmax");
     auto labels_dims = ctx->GetInputDim("Label");
-    PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
-                      "The labels should be a 2-D tensor.");
+
+    int rank = softmax_dims.size();
+    PADDLE_ENFORCE_EQ(
+        rank, labels_dims.size(),
+        "Input(logits) and Input(Label) shall have the same rank.");
+    bool check = true;
+    if ((!ctx->IsRuntime()) && (framework::product(softmax_dims) <= 0 ||
+                                framework::product(labels_dims) <= 0)) {
+      check = false;
+    }
+    if (check) {
+      PADDLE_ENFORCE_EQ(
+          framework::slice_ddim(softmax_dims, 0, rank - 1),
+          framework::slice_ddim(labels_dims, 0, rank - 1),
+          "Input(Softmax) and Input(Label) shall have the same shape "
+          "except the last dimension.");
+    }
 
     if (ctx->Attrs().Get<bool>("soft_label")) {
-      PADDLE_ENFORCE_EQ(softmax_dims[1], labels_dims[1],
-                        "When Attr(soft_label) == true, the 2nd dimension of "
-                        "Input(X) and Input(Label) should be equal.");
+      if (check) {
+        PADDLE_ENFORCE_EQ(softmax_dims[rank - 1], labels_dims[rank - 1],
+                          "If Attr(soft_label) == true, the last dimension of "
+                          "Input( Softmax) and Input(Label) should be equal.");
+      }
     } else {
-      PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
-                        "When Attr(soft_label) == false, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
+                        "If Attr(softLabel) == false, the last dimension of "
                         "Input(Label) should be 1.");
     }
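The gradient op now performs the same prefix-shape consistency check; comparing framework::slice_ddim(dims, 0, rank - 1) of two tensors amounts to comparing their shapes with the last dimension dropped, roughly as in this numpy sketch (shapes are illustrative):

import numpy as np

def same_shape_except_last(a, b):
    # Rough equivalent of comparing slice_ddim(dims, 0, rank - 1) of both inputs.
    return a.shape[:-1] == b.shape[:-1]

softmax = np.zeros((6, 10, 47))
labels = np.zeros((6, 10, 1))
assert same_shape_except_last(softmax, labels)
assert not same_shape_except_last(softmax, np.zeros((6, 9, 1)))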

paddle/fluid/operators/softmax_with_cross_entropy_op.cu

Lines changed: 30 additions & 10 deletions
@@ -400,24 +400,39 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
 
     auto soft_label = context.Attr<bool>("soft_label");
     auto ignore_index = context.Attr<int>("ignore_index");
+
+    int rank = logits->dims().size();
     if (soft_label) {
-      int batch_size = logits->dims()[0];
-      int feature_size = logits->dims()[1];
+      int batch_size = 1;
+      for (int i = 0; i < rank - 1; ++i) {
+        batch_size *= logits->dims()[i];
+      }
+
+      int feature_size = logits->dims()[rank - 1];
       auto* logits_data = logits->data<T>();
       auto* labels_data = labels->data<T>();
       SoftmaxWithCrossEntropyFusedKernel(
          logits_data, labels_data, softmax_data, loss_data, batch_size,
          feature_size, context.cuda_device_context().stream());
     } else {
       if (!context.Attr<bool>("numeric_stable_mode")) {
-        math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(), logits,
-                                       softmax);
+        // reshape to 2d
+        Tensor logits_2d = framework::ReshapeToMatrix(*logits, rank - 1);
+        Tensor softmax_2d = framework::ReshapeToMatrix(*softmax, rank - 1);
+        Tensor loss_2d = framework::ReshapeToMatrix(*loss, rank - 1);
+        Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
+
+        math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(),
+                                       &logits_2d, &softmax_2d);
         math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
-            context.cuda_device_context(), loss, softmax, labels, false,
-            ignore_index);
+            context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
+            false, ignore_index);
       } else {
-        int batch_size = logits->dims()[0];
-        int feature_size = logits->dims()[1];
+        int batch_size = 1;
+        for (int i = 0; i < rank - 1; ++i) {
+          batch_size *= logits->dims()[i];
+        }
+        int feature_size = logits->dims()[rank - 1];
        auto* logits_data = logits->data<T>();
        auto* labels_data = labels->data<int64_t>();
        HardLabelSoftmaxWithCrossEntropy<T>(
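In the CUDA forward kernel, high-rank support comes from flattening every leading dimension into one batch dimension, either explicitly (batch_size is the product of all but the last dimension) or via framework::ReshapeToMatrix. A numpy sketch of the same flattening, with illustrative shapes:

import numpy as np

logits = np.random.rand(6, 10, 47)
rank = logits.ndim

# Mirrors the loop above: product of all but the last dimension.
batch_size = int(np.prod(logits.shape[:rank - 1]))   # 60
feature_size = logits.shape[rank - 1]                # 47

# ReshapeToMatrix(*logits, rank - 1) corresponds to this 2-D view.
logits_2d = logits.reshape(batch_size, feature_size)
assert logits_2d.shape == (60, 47)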
@@ -443,8 +458,13 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
                                       context.device_context(), logit_grad);
     T* logit_grad_data = logit_grad->data<T>();
 
-    const int batch_size = logit_grad->dims()[0];
-    const int class_num = logit_grad->dims()[1];
+    int rank = logit_grad->dims().size();
+    int batch_size = 1;
+    for (int i = 0; i < rank - 1; ++i) {
+      batch_size *= logit_grad->dims()[i];
+    }
+
+    const int class_num = logit_grad->dims()[rank - 1];
     int block = 512;
     auto stream = context.cuda_device_context().stream();
     auto ignore_index = context.Attr<int>("ignore_index");

paddle/fluid/operators/softmax_with_cross_entropy_op.h

Lines changed: 23 additions & 9 deletions
@@ -40,15 +40,22 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
     softmax->mutable_data<T>(context.GetPlace());
     loss->mutable_data<T>(context.GetPlace());
 
-    int axis_dim = logits->dims()[logits->dims().size() - 1];
+    // reshape to 2D tensor
+    int rank = logits->dims().size();
+    Tensor logits_2d = framework::ReshapeToMatrix(*logits, rank - 1);
+    Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
+    Tensor loss_2d = framework::ReshapeToMatrix(*loss, rank - 1);
+    Tensor softmax_2d = framework::ReshapeToMatrix(*softmax, rank - 1);
+
+    int axis_dim = logits->dims()[rank - 1];
 
     auto& dev_ctx =
         context.template device_context<platform::CPUDeviceContext>();
     math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
-        dev_ctx, axis_dim, logits, softmax);
+        dev_ctx, axis_dim, &logits_2d, &softmax_2d);
     math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
-        dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"),
-        context.Attr<int>("ignore_index"));
+        dev_ctx, &loss_2d, &softmax_2d, &labels_2d,
+        context.Attr<bool>("soft_label"), context.Attr<int>("ignore_index"));
   }
 };
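Reshaping to 2-D is safe here because both the softmax and the cross-entropy functors operate along the last axis only; computing on the flattened view and reshaping back gives the same result as computing on the high-rank tensor. A numpy sketch of that equivalence (stable_softmax is written out here for self-containment; the test file defines its own helper):

import numpy as np

def stable_softmax(x):
    shifted = x - np.max(x)
    exps = np.exp(shifted)
    return exps / np.sum(exps)

logits = np.random.uniform(0.1, 1.0, (6, 10, 47))

# High-rank path: softmax along the last axis.
softmax_hr = np.apply_along_axis(stable_softmax, 2, logits)

# 2-D path: flatten the leading dims first, as ReshapeToMatrix does.
softmax_2d = np.apply_along_axis(stable_softmax, 1, logits.reshape(-1, 47))
assert np.allclose(softmax_hr, softmax_2d.reshape(6, 10, 47))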

@@ -63,13 +70,19 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
         context.Output<Tensor>(framework::GradVarName("Logits"));
     logit_grad->ShareDataWith(*context.Input<Tensor>("Softmax"));
 
-    const int class_num = logit_grad->dims()[1];
-    auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
-    auto logit_grad_mat = EigenMatrix<T>::From(*logit_grad);
+    int rank = logit_grad->dims().size();
+    const int class_num = logit_grad->dims()[rank - 1];
+    // reshape to 2d
+    Tensor logit_grad_2d = framework::ReshapeToMatrix(*logit_grad, rank - 1);
+    Tensor out_grad_2d = framework::ReshapeToMatrix(*out_grad, rank - 1);
+
+    auto out_grad_mat = EigenMatrix<T>::From(out_grad_2d);
+    auto logit_grad_mat = EigenMatrix<T>::From(logit_grad_2d);
     auto& place = *context.template device_context<platform::CPUDeviceContext>()
                        .eigen_device();
     if (context.Attr<bool>("soft_label")) {
-      auto lbl_mat = EigenMatrix<T>::From(*labels);
+      Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
+      auto lbl_mat = EigenMatrix<T>::From(labels_2d);
       logit_grad_mat.device(place) =
           out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) *
           (logit_grad_mat - lbl_mat);
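For soft labels the expression above computes dLogits = dLoss * (softmax - label), with dLoss broadcast across the class dimension (logit_grad_mat initially holds the forward Softmax via ShareDataWith). A numpy sketch of the same expression on flattened rows, with illustrative sizes:

import numpy as np

batch, class_num = 60, 47
softmax = np.random.dirichlet(np.ones(class_num), size=batch)     # rows sum to 1
soft_label = np.random.dirichlet(np.ones(class_num), size=batch)
out_grad = np.random.rand(batch, 1)                               # dLoss, one value per row

# Same as logit_grad_mat = out_grad.broadcast(...) * (softmax - label).
logit_grad = out_grad * (softmax - soft_label)
assert logit_grad.shape == (batch, class_num)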
@@ -78,7 +91,8 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
           logit_grad_mat *
           out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num));
 
-      const int batch_size = logit_grad->dims()[0];
+      const int batch_size = logit_grad_2d.dims()[0];
+
       const int64_t* label_data = labels->data<int64_t>();
       T* logit_grad_data = logit_grad->data<T>();
       const T* out_grad_data = out_grad->data<T>();
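For hard labels only the rescaling by the incoming loss gradient is visible in this hunk; the well-known closed form of the full gradient is dLogits = dLoss * (softmax - one_hot(label)). The numpy sketch below is a mathematical reference for that formula, not a transcription of the kernel code (it also ignores the op's ignore_index handling):

import numpy as np

batch, class_num = 60, 47
softmax = np.random.dirichlet(np.ones(class_num), size=batch)
label = np.random.randint(0, class_num, size=batch)
out_grad = np.random.rand(batch, 1)

grad = softmax.copy()
grad[np.arange(batch), label] -= 1.0   # softmax - one_hot(label)
grad *= out_grad                       # scale by the incoming loss gradient
assert grad.shape == (batch, class_num)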

python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py

Lines changed: 139 additions & 0 deletions
@@ -195,5 +195,144 @@ def initParams(self):
         self.numeric_stable_mode = True
 
 
+class TestSoftmaxWithCrossEntropyOp5(OpTest):
+    """
+    Test softmax with cross entropy operator with ignore_index.
+    """
+
+    def initParams(self):
+        self.numeric_stable_mode = False
+
+    def setUp(self):
+        self.initParams()
+        self.op_type = "softmax_with_cross_entropy"
+        batch_size = [6, 10]
+        class_num = 47
+
+        logits = np.random.uniform(
+            0.1, 1.0, tuple(batch_size + [class_num])).astype("float64")
+        softmax = np.apply_along_axis(stable_softmax, 2, logits)
+        labels = np.random.randint(
+            0, class_num, tuple(batch_size + [1]), dtype="int64")
+        ignore_index = 7
+
+        softmax_2d = np.reshape(softmax, [-1, class_num])
+        labels_2d = np.reshape(labels, [-1, 1])
+        cross_entropy = np.asmatrix(
+            [[-np.log(softmax_2d[i][labels_2d[i][0]])]
+             if labels_2d[i] != ignore_index else [0]
+             for i in range(softmax_2d.shape[0])],
+            dtype="float64")
+
+        cross_entropy = np.reshape(cross_entropy, batch_size)
+
+        output_shape = tuple(batch_size + [1])
+        output_res = cross_entropy.astype("float64")
+        output_res = np.expand_dims(output_res, axis=2)
+        self.inputs = {"Logits": logits, "Label": labels}
+        self.outputs = {
+            "Softmax": softmax.astype("float64"),
+            "Loss": output_res,
+        }
+        self.attrs = {
+            "ignore_index": ignore_index,
+            "numeric_stable_mode": self.numeric_stable_mode
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["Logits"], "Loss")
+
+
+class TestSoftmaxWithCrossEntropyOp5NoCudnn(TestSoftmaxWithCrossEntropyOp5):
+    def initParams(self):
+        self.numeric_stable_mode = True
+
+
+class TestSoftmaxWithCrossEntropyOp6(OpTest):
+    """
+    Test softmax with cross entropy operator with soft labels.
+    """
+
+    def setUp(self):
+        self.op_type = "softmax_with_cross_entropy"
+        batch_size = [6, 10]
+        class_num = 37
+
+        logits = np.random.uniform(
+            0.1, 1.0, tuple(batch_size + [class_num])).astype("float64")
+        softmax = np.apply_along_axis(stable_softmax, 2, logits)
+        labels = np.random.uniform(
+            0.1, 1.0, tuple(batch_size + [class_num])).astype("float64")
+        labels /= np.sum(labels, axis=2, keepdims=True)
+
+        cross_entropy = (-labels * np.log(softmax)).sum(
+            axis=2, keepdims=True).astype("float64")
+
+        self.inputs = {"Logits": logits, "Label": labels}
+        self.outputs = {
+            "Softmax": softmax.astype("float64"),
+            "Loss": cross_entropy.astype("float64")
+        }
+        self.attrs = {"soft_label": True}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["Logits"], "Loss")
+
+
+class TestSoftmaxWithCrossEntropyOpFp16_2(TestSoftmaxWithCrossEntropyOp):
+    def initParams(self):
+        self.numeric_stable_mode = False
+        self.dtype = np.float16
+
+    def setUp(self):
+        self.initParams()
+        self.op_type = "softmax_with_cross_entropy"
+        batch_size = [64, 10]
+        class_num = 37
+
+        # NOTE: numpy float16 have very low accuracy, use float32 for numpy check.
+        logits = np.random.uniform(
+            0.1, 1.0, tuple(batch_size + [class_num])).astype(np.float32)
+        softmax = np.apply_along_axis(stable_softmax, 2, logits)
+        labels = np.random.randint(
+            0, class_num, tuple(batch_size + [1]), dtype="int64")
+
+        softmax_2d = np.reshape(softmax, [-1, class_num])
+        labels_2d = np.reshape(labels, [-1, 1])
+
+        cross_entropy = np.asmatrix(
+            [[-np.log(softmax_2d[i][labels_2d[i][0]])]
+             for i in range(softmax_2d.shape[0])],
+            dtype=np.float32)
+
+        cross_entropy = np.reshape(cross_entropy, batch_size)
+        output_shape = tuple(batch_size + [1])
+        output_res = cross_entropy.astype(self.dtype)
+        output_res = np.expand_dims(output_res, axis=2)
+        self.inputs = {"Logits": logits, "Label": labels}
+
+        self.inputs = {
+            "Logits": logits.astype(self.dtype).view(np.uint16),
+            "Label": labels
+        }
+        self.outputs = {
+            "Softmax": softmax.astype(self.dtype),
+            "Loss": output_res,
+        }
+        self.attrs = {"numeric_stable_mode": self.numeric_stable_mode}
+
+    def test_check_output(self):
+        self.check_output(atol=1e-2)
+
+    def test_check_grad(self):
+        self.check_grad(["Logits"], "Loss", max_relative_error=0.1)
+
 if __name__ == "__main__":
     unittest.main()
