Skip to content

Commit bc385a2

Browse files
authored
fix softmax_with_cross_entropy_fix bug, test=develop (#21810) (#22183)
1 parent 515b206 commit bc385a2

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

paddle/fluid/operators/softmax_with_cross_entropy_op.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -415,19 +415,19 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
415415
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
416416
int axis_dim = logits->dims()[axis];
417417

418+
const int n = SizeToAxis(axis, logits->dims());
419+
const int d = SizeFromAxis(axis, logits->dims());
420+
421+
auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
422+
auto* loss_data = loss->mutable_data<T>(context.GetPlace());
423+
418424
if (axis_dim == 1) {
419425
math::SetConstant<platform::CUDADeviceContext, T> set_constant;
420426
set_constant(context.cuda_device_context(), softmax, static_cast<T>(1));
421427
set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
422428
return;
423429
}
424430

425-
const int n = SizeToAxis(axis, logits->dims());
426-
const int d = SizeFromAxis(axis, logits->dims());
427-
428-
auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
429-
auto* loss_data = loss->mutable_data<T>(context.GetPlace());
430-
431431
auto soft_label = context.Attr<bool>("soft_label");
432432
auto ignore_index = context.Attr<int>("ignore_index");
433433

python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,23 @@ def initParams(self):
280280
self.shape = [3, 5, 7, 11]
281281

282282

283+
class TestSoftmaxWithCrossEntropyOpAxisDimEqualOne(
284+
TestSoftmaxWithCrossEntropyOp):
285+
"""
286+
Test softmax with cross entropy operator with discreate one-hot labels.
287+
Given axis != -1
288+
"""
289+
290+
def initParams(self):
291+
self.op_type = "softmax_with_cross_entropy"
292+
self.numeric_stable_mode = True
293+
self.soft_label = False
294+
self.dtype = np.float64
295+
self.axis = -1
296+
self.ignore_index = -1
297+
self.shape = [3, 5, 7, 1]
298+
299+
283300
class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1(
284301
TestSoftmaxWithCrossEntropyOpNoCudnnFp16):
285302
def initParams(self):

0 commit comments

Comments
 (0)