@@ -23,21 +23,21 @@ using Tensor = framework::Tensor;
23
23
24
24
namespace {
25
25
// Hard-label cross-entropy gradient step.
//
// For every sample i in [0, batch_size), subtracts 1 from the element of
// logit_grad selected by that sample's label:
//   logit_grad[i * class_num + labels[i]] -= 1
//
// Parameters:
//   logit_grad  buffer updated in place; addressed as batch_size rows of
//               class_num elements each.
//   labels      one int64 class index per sample.
//   batch_size  number of samples (rows).
//   class_num   number of elements per row.
//
// Preconditions: labels[i] is NOT range-checked here; a value outside
// [0, class_num) results in an out-of-row (possibly out-of-bounds) write.
//
// Launch: 1-D grid/block. The grid-stride loop makes the kernel correct for
// any grid size, including grids smaller than batch_size.
template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
                                 const int batch_size, const int class_num) {
  // All index arithmetic is done in 64 bits: i * class_num can exceed INT_MAX
  // for large batch_size * class_num, and adding the unsigned stride to a
  // 32-bit signed counter could wrap around.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < batch_size; i += stride) {
    const int64_t idx = i * class_num + labels[i];
    logit_grad[idx] -= static_cast<T>(1.);
  }
}
36
34
37
- __syncthreads ();
38
-
39
- if (tid < batch_size * class_num) {
40
- logit_grad[tid] *= loss_grad[sample_idx];
35
// Multiplies each element of logit_grad by a per-row factor from loss_grad:
//   logit_grad[i] *= loss_grad[i / class_num]   for i in [0, num)
//
// Parameters:
//   logit_grad  buffer of num elements, scaled in place.
//   loss_grad   per-row factors; element i / class_num is read for index i.
//   num         total number of elements to scale.
//   class_num   number of consecutive elements that share one factor.
//
// Launch: 1-D grid/block. The grid-stride loop makes the kernel correct for
// any grid size.
template <typename T>
__global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
                      const int class_num) {
  // 64-bit counter: adding the unsigned stride to a 32-bit signed index could
  // wrap to a negative value when num is close to INT_MAX.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < num; i += stride) {
    logit_grad[i] *= loss_grad[i / class_num];
  }
}
43
43
@@ -94,22 +94,22 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
94
94
const int batch_size = logit_grad->dims ()[0 ];
95
95
const int class_num = logit_grad->dims ()[1 ];
96
96
int block = 512 ;
97
- int grid = (batch_size * class_num + block - 1 ) / block ;
97
+ auto stream = context. cuda_device_context (). stream () ;
98
98
99
99
if (context.Attr <bool >(" soft_label" )) {
100
+ int grid = (batch_size * class_num + block - 1 ) / block;
100
101
const T* label_data = labels->data <T>();
101
- SoftCrossEntropyGradientKernel<
102
- T><<<grid, block, 0 ,
103
- context.template device_context<platform::CUDADeviceContext>()
104
- .stream()>>> (logit_grad_data, loss_grad_data, label_data,
105
- batch_size, class_num);
102
+ SoftCrossEntropyGradientKernel<T><<<grid, block, 0 , stream>>> (
103
+ logit_grad_data, loss_grad_data, label_data, batch_size, class_num);
106
104
} else {
105
+ int grid = (batch_size + block - 1 ) / block;
107
106
const int64_t * label_data = labels->data <int64_t >();
108
- CrossEntropyGrad<
109
- T><<<grid, block, 0 ,
110
- context.template device_context<platform::CUDADeviceContext>()
111
- .stream()>>> (logit_grad_data, loss_grad_data, label_data,
112
- batch_size, class_num);
107
+ CrossEntropyGrad<T><<<grid, block, 0 , stream>>> (
108
+ logit_grad_data, label_data, batch_size, class_num);
109
+ int num = batch_size * class_num;
110
+ grid = (num + block - 1 ) / block;
111
+ Scale<T><<<grid, block, 0 , stream>>> (logit_grad_data, loss_grad_data, num,
112
+ class_num);
113
113
}
114
114
}
115
115
};
0 commit comments