
Commit b5a16dc

Fix a critical bug in softmax_with_cross_entropy_op backward. (#9120)
* Fix a critical bug in softmax_with_cross_entropy_op that leads to wrong gradients.
* Enhance unit testing.
1 parent 1e4c504 commit b5a16dc
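Background on the bug this commit fixes: in the old CUDA kernel, the "subtract 1 from the true-class probability" step and the "scale by the per-sample loss gradient" step were fused into a single launch, separated only by __syncthreads(), which synchronizes one thread block rather than the whole grid. Once the launch spans more than one block, the scaling can run before the subtraction for elements owned by other blocks, producing wrong gradients. The fix splits the two steps into separate kernel launches (CrossEntropyGrad, then Scale), which gives a grid-wide ordering guarantee.

For reference, the gradient the backward pass should produce for hard (int64) labels is softmax(logits) minus the one-hot label, scaled per sample by the incoming loss gradient. A minimal NumPy sketch of that expected result (illustrative only, not part of the commit; the function name is made up):

import numpy as np

def softmax_with_cross_entropy_grad(logits, labels, loss_grad):
    """Reference gradient w.r.t. logits for hard (int64) labels.

    logits:    (batch_size, class_num) float array
    labels:    (batch_size,) integer class indices
    loss_grad: (batch_size,) upstream gradient of the per-sample loss
    """
    # Row-wise softmax, shifted for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    softmax = np.exp(shifted)
    softmax /= softmax.sum(axis=1, keepdims=True)

    # dLoss/dlogits = (softmax - one_hot(labels)) * loss_grad, per sample.
    grad = softmax.copy()
    grad[np.arange(labels.shape[0]), labels] -= 1.0  # what CrossEntropyGrad does
    grad *= loss_grad[:, np.newaxis]                 # what Scale does
    return grad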

2 files changed: +26 / -26 lines

paddle/fluid/operators/softmax_with_cross_entropy_op.cu

Lines changed: 24 additions & 24 deletions
@@ -23,21 +23,21 @@ using Tensor = framework::Tensor;
 
 namespace {
 template <typename T>
-__global__ void CrossEntropyGrad(T* logit_grad, const T* loss_grad,
-                                 const int64_t* labels, const int batch_size,
-                                 const int class_num) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  int sample_idx = tid / class_num;
-
-  if (tid < batch_size) {
-    PADDLE_ASSERT(labels[sample_idx] >= 0 && labels[sample_idx] < class_num);
-    logit_grad[tid * class_num + labels[tid]] -= static_cast<T>(1.);
+__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
+                                 const int batch_size, const int class_num) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size;
+       i += blockDim.x * gridDim.x) {
+    int idx = i * class_num + labels[i];
+    logit_grad[idx] -= static_cast<T>(1.);
   }
+}
 
-  __syncthreads();
-
-  if (tid < batch_size * class_num) {
-    logit_grad[tid] *= loss_grad[sample_idx];
+template <typename T>
+__global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
+                      const int class_num) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
+       i += blockDim.x * gridDim.x) {
+    logit_grad[i] *= loss_grad[i / class_num];
   }
 }
 

@@ -94,22 +94,22 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     const int batch_size = logit_grad->dims()[0];
     const int class_num = logit_grad->dims()[1];
     int block = 512;
-    int grid = (batch_size * class_num + block - 1) / block;
+    auto stream = context.cuda_device_context().stream();
 
     if (context.Attr<bool>("soft_label")) {
+      int grid = (batch_size * class_num + block - 1) / block;
       const T* label_data = labels->data<T>();
-      SoftCrossEntropyGradientKernel<
-          T><<<grid, block, 0,
-               context.template device_context<platform::CUDADeviceContext>()
-                   .stream()>>>(logit_grad_data, loss_grad_data, label_data,
-                                batch_size, class_num);
+      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
+          logit_grad_data, loss_grad_data, label_data, batch_size, class_num);
     } else {
+      int grid = (batch_size + block - 1) / block;
       const int64_t* label_data = labels->data<int64_t>();
-      CrossEntropyGrad<
-          T><<<grid, block, 0,
-               context.template device_context<platform::CUDADeviceContext>()
-                   .stream()>>>(logit_grad_data, loss_grad_data, label_data,
-                                batch_size, class_num);
+      CrossEntropyGrad<T><<<grid, block, 0, stream>>>(
+          logit_grad_data, label_data, batch_size, class_num);
+      int num = batch_size * class_num;
+      grid = (num + block - 1) / block;
+      Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num,
+                                           class_num);
     }
   }
 };

python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
 
     def setUp(self):
         self.op_type = "softmax_with_cross_entropy"
-        batch_size = 2
+        batch_size = 41
         class_num = 37
 
         logits = np.random.uniform(0.1, 1.0,
@@ -59,7 +59,7 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest):
 
     def setUp(self):
         self.op_type = "softmax_with_cross_entropy"
-        batch_size = 2
+        batch_size = 41
         class_num = 37
 
         logits = np.random.uniform(0.1, 1.0,
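
The larger test batch size is presumably what exposes the original race: with block = 512 and class_num = 37, the old single launch over batch_size * class_num elements fits in one thread block at batch_size = 2 (where __syncthreads() happens to enforce the right ordering), but spans several blocks at batch_size = 41. Illustrative arithmetic only, values taken from the kernel and the test:

block = 512      # thread-block size used by the CUDA kernel launches
class_num = 37   # class count used in the unit tests

for batch_size in (2, 41):
    num_elements = batch_size * class_num
    grid = (num_elements + block - 1) // block
    print(batch_size, num_elements, grid)

# 2    74    1  -> single block: the old intra-block sync hid the bug
# 41   1517  3  -> multiple blocks: the cross-block race can surface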
