- /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@ limitations under the License. */

#define EIGEN_USE_GPU

+ #include <cub/cub.cuh>
+ #include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"

namespace paddle {
@@ -53,8 +55,196 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
    logit_grad[ids] = loss_grad[row_ids] * (logit_grad[ids] - labels[ids]);
  }
}
+
}  // namespace

+ static __device__ __forceinline__ float real_exp(float x) { return expf(x); }
+ static __device__ __forceinline__ double real_exp(double x) { return exp(x); }
+ static __device__ __forceinline__ float real_log(float x) {
+   return math::TolerableValue<float>()(logf(x));
+ }
+ static __device__ __forceinline__ double real_log(double x) {
+   return math::TolerableValue<double>()(log(x));
+ }
+
+ /** In the following code, 3 CUDA kernels are implemented to calculate softmax
+  * and loss. **/
+ /*
+   Suppose x is `logits` and y is `labels`; the equations are as follows:
+
+   cross\_entropy_i = \sum_{j}[- y_i_j * log({e^{x_i_j}/\sum_{j}e^{x_i_j}})]
+       = \sum_{j}[- y_i_j * log({e^{x_i_j - max_i}/\sum_{j}e^{x_i_j-max_i}})]
+       = \sum_{j}[-y_i_j * (x_i_j - max_i - log\sum_{j}e^{x_i_j - max_i})]
+       = \sum_{j}[-y_i_j * (x_i_j - max_i - logDiffMaxSum_i)]
+       = \sum_{j}(-y_i_j * tmp_i_j)
+
+   softmax_i_j = e^{tmp_i_j}
+
+   where:
+     max_i = \max_{j}{x_i_j}
+     logDiffMaxSum_i = log\sum_{j}e^{x_i_j - max_i}
+     tmp_i_j = x_i_j - max_i - logDiffMaxSum_i
+
+   Therefore, the calculation can be separated into 3 steps:
+   Step 1: row-wise operation to calculate max_i
+   Step 2: row-wise operation to calculate logDiffMaxSum_i
+   Step 3: calculate tmp_i_j, and finally get softmax_i_j and cross\_entropy_i
+
+   To save memory, we can share memory among max_i, logDiffMaxSum_i and
+   cross\_entropy_i. In this way, the 3 steps become:
+   Step 1 (RowReductionForMax): row-wise operation to calculate max_i
+   Step 2 (RowReductionForDiffMaxSum): calculate the intermediate result
+     softmax'_i_j = x_i_j - max_i, and a row-wise operation to calculate
+     logDiffMaxSum_i
+   Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate
+     tmp_i_j = softmax'_i_j - logDiffMaxSum_i, and finally get softmax_i_j
+     and cross\_entropy_i
+ */
+
+ // There are 3 kinds of reduce algorithms in cub:
+ //     BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
+ //     BLOCK_REDUCE_RAKING
+ //     BLOCK_REDUCE_WARP_REDUCTIONS (default)
+ template <typename T, int BlockDim>
+ using BlockReduce =
+     cub::BlockReduce<T, BlockDim /*, cub::BLOCK_REDUCE_WARP_REDUCTIONS*/>;
+
+ template <typename T, int BlockDim>
+ using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
+
+ // Make sure that BlockDim <= feature_size
+ // This kernel is used to calculate the max element of each row
+ template <typename T, int BlockDim>
+ __global__ void RowReductionForMax(const T* logits_data, T* max_data,
+                                    int feature_size) {
+   __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
+
+   auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
+   auto end_idx = feature_size * (blockIdx.x + 1);
+
+   T cur_max = logits_data[beg_idx];
+   beg_idx += BlockDim;
+   while (beg_idx < end_idx) {
+     if (cur_max < logits_data[beg_idx]) {
+       cur_max = logits_data[beg_idx];
+     }
+     beg_idx += BlockDim;
+   }
+
+   cur_max = BlockReduce<T, BlockDim>(temp_storage).Reduce(cur_max, cub::Max());
+
+   if (threadIdx.x == 0) {
+     max_data[blockIdx.x] = cur_max < -64 ? -64 : cur_max;
+   }
+ }
+
+ // Make sure that BlockDim <= feature_size
+ template <typename T, int BlockDim>
+ __global__ void RowReductionForDiffMaxSum(const T* logits_data, T* max_data,
+                                           T* softmax, int feature_size) {
+   __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
+
+   auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
+   auto end_idx = feature_size * (blockIdx.x + 1);
+
+   auto block_max = max_data[blockIdx.x];
+
+   softmax[beg_idx] = logits_data[beg_idx] - block_max;
+   T diff_max_sum = real_exp(softmax[beg_idx]);
+   beg_idx += BlockDim;
+   while (beg_idx < end_idx) {
+     softmax[beg_idx] = logits_data[beg_idx] - block_max;
+     diff_max_sum += real_exp(softmax[beg_idx]);
+     beg_idx += BlockDim;
+   }
+
+   diff_max_sum =
+       BlockReduce<T, BlockDim>(temp_storage).Reduce(diff_max_sum, cub::Sum());
+   if (threadIdx.x == 0) max_data[blockIdx.x] = real_log(diff_max_sum);
+ }
+
+ // Make sure that BlockDim <= feature_size
+ template <typename T, int BlockDim>
+ __global__ void RowReductionForSoftmaxAndCrossEntropy(const T* logits_data,
+                                                       const T* labels_data,
+                                                       T* loss_data, T* softmax,
+                                                       int feature_size) {
+   __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
+
+   auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
+   auto end_idx = feature_size * (blockIdx.x + 1);
+
+   // log_diff_max_sum shares memory with loss
+   auto block_log_diff_max_sum = loss_data[blockIdx.x];
+   auto tmp = softmax[beg_idx] - block_log_diff_max_sum;
+   softmax[beg_idx] = real_exp(tmp);
+   auto loss = -labels_data[beg_idx] * tmp;
+   beg_idx += BlockDim;
+   while (beg_idx < end_idx) {
+     tmp = softmax[beg_idx] - block_log_diff_max_sum;
+     softmax[beg_idx] = real_exp(tmp);
+     loss -= (labels_data[beg_idx] * tmp);
+     beg_idx += BlockDim;
+   }
+
+   loss = BlockReduce<T, BlockDim>(temp_storage).Reduce(loss, cub::Sum());
+   if (threadIdx.x == 0) loss_data[blockIdx.x] = loss;
+ }
+
+ template <typename T>
+ __global__ void SetSoftmaxToOneWhenFeatureSizeIsOne(T* out, int batch_size) {
+   auto idx = threadIdx.x + blockIdx.x * blockDim.x;
+   if (idx < batch_size) out[idx] = static_cast<T>(1);
+ }
+
+ template <typename T>
+ static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
+                                                const T* labels_data,
+                                                T* softmax_data, T* loss_data,
+                                                int batch_size, int feature_size,
+                                                cudaStream_t stream) {
+   constexpr int kMaxBlockDim = 512;
+   int block_dim = feature_size >= kMaxBlockDim
+                       ? kMaxBlockDim
+                       : (1 << static_cast<int>(std::log2(feature_size)));
+
+ #define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)              \
+   case BlockDim:                                                            \
+     RowReductionForMax<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>(   \
+         logits_data, loss_data, feature_size);                              \
+     RowReductionForDiffMaxSum<T,                                            \
+                               BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
+         logits_data, loss_data, softmax_data, feature_size);                \
+     RowReductionForSoftmaxAndCrossEntropy<                                  \
+         T, BlockDim><<<batch_size, BlockDim, 0, stream>>>(                  \
+         logits_data, labels_data, loss_data, softmax_data, feature_size);   \
+     break
+
+   switch (block_dim) {
+     CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
+     CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
+     CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
+     CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
+     CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
+     CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
+     CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
+     CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
+     CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
+     case 1:
+       SetSoftmaxToOneWhenFeatureSizeIsOne<<<(batch_size + kMaxBlockDim - 1) /
+                                                 kMaxBlockDim,
+                                             kMaxBlockDim, 0, stream>>>(
+           softmax_data, batch_size);
+       // Zero all batch_size loss elements, not just batch_size bytes.
+       cudaMemsetAsync(loss_data, 0, batch_size * sizeof(T), stream);
+       break;
+     default:
+       PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
+       break;
+   }
+
+ #undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
+ }
+
template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
 public:
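
Note on the fused path added above: the three RowReduction kernels must run in order, because loss_data is reused as scratch. It holds max_i after step 1, logDiffMaxSum_i after step 2, and the per-row loss after step 3. As a plain host-side reference for the derivation in the comment block, here is a minimal sketch of the same three steps for one row; the function name and layout are invented for this note and are not part of the commit:

#include <algorithm>
#include <cmath>

// Reference (non-CUDA) version of the three steps for a single row.
// Illustrative only; names do not come from the commit.
void SoftmaxCrossEntropyRowRef(const float* x, const float* y, int n,
                               float* softmax, float* loss) {
  // Step 1: max_i
  float max_v = *std::max_element(x, x + n);
  // Step 2: logDiffMaxSum_i = log(sum_j exp(x_j - max_i))
  float sum = 0.f;
  for (int j = 0; j < n; ++j) sum += std::exp(x[j] - max_v);
  float log_diff_max_sum = std::log(sum);
  // Step 3: tmp_j = x_j - max_i - logDiffMaxSum_i;
  // softmax_j = exp(tmp_j), loss = sum_j (-y_j * tmp_j)
  *loss = 0.f;
  for (int j = 0; j < n; ++j) {
    float tmp = x[j] - max_v - log_diff_max_sum;
    softmax[j] = std::exp(tmp);
    *loss -= y[j] * tmp;
  }
}

Comparing one row of the GPU output against a reference like this is a cheap sanity check for the kernels.
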
@@ -66,14 +256,24 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
    Tensor* softmax = context.Output<Tensor>("Softmax");

    Tensor* loss = context.Output<Tensor>("Loss");
-   softmax->mutable_data<T>(context.GetPlace());
-   loss->mutable_data<T>(context.GetPlace());
-
-   math::SoftmaxFunctor<platform::CUDADeviceContext, T>()(
-       context.cuda_device_context(), logits, softmax);
-   math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
-       context.cuda_device_context(), loss, softmax, labels,
-       context.Attr<bool>("soft_label"));
+   auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
+   auto* loss_data = loss->mutable_data<T>(context.GetPlace());
+
+   auto soft_label = context.Attr<bool>("soft_label");
+   if (soft_label) {
+     int batch_size = logits->dims()[0];
+     int feature_size = logits->dims()[1];
+     auto* logits_data = logits->data<T>();
+     auto* labels_data = labels->data<T>();
+     SoftmaxWithCrossEntropyFusedKernel(
+         logits_data, labels_data, softmax_data, loss_data, batch_size,
+         feature_size, context.cuda_device_context().stream());
+   } else {
+     math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(), logits,
+                                    softmax);
+     math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
+         context.cuda_device_context(), loss, softmax, labels, false);
+   }
  }
};
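
Two notes on SoftmaxWithCrossEntropyFusedKernel. First, block_dim is feature_size rounded down to a power of two and capped at kMaxBlockDim = 512 (that is what `1 << static_cast<int>(std::log2(feature_size))` computes), so the switch can select a compile-time BlockDim template instantiation; the default case only fires for unexpected inputs. Second, every kernel uses the same cub::BlockReduce pattern: each thread strides over one row accumulating a partial result, then the block reduces the partials. A minimal standalone sketch of that pattern, with a kernel name invented for this note:

#include <cub/cub.cuh>

// Minimal example of the reduction pattern shared by the three RowReduction
// kernels: one block per row, BlockDim threads striding across the columns.
// Precondition (as in the commit): BlockDim <= feature_size.
template <typename T, int BlockDim>
__global__ void RowMaxSketch(const T* data, T* out, int feature_size) {
  using Reduce = cub::BlockReduce<T, BlockDim>;
  __shared__ typename Reduce::TempStorage temp_storage;

  const T* row = data + feature_size * blockIdx.x;
  T thread_max = row[threadIdx.x];
  for (int j = threadIdx.x + BlockDim; j < feature_size; j += BlockDim) {
    if (thread_max < row[j]) thread_max = row[j];
  }

  // Every thread in the block must call Reduce(); only thread 0
  // receives the valid block-wide result, hence the guarded store.
  T row_max = Reduce(temp_storage).Reduce(thread_max, cub::Max());
  if (threadIdx.x == 0) out[blockIdx.x] = row_max;
}

A launch such as RowMaxSketch<float, 256><<<num_rows, 256, 0, stream>>>(data, out, feature_size) mirrors the one-block-per-row launches used above.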