Commit cf799a6

Merge pull request #12553 from sneaxiy/refine_softmax_with_cross_entropy
Refine softmax_with_cross_entropy op
2 parents 772ceee + 1b4515f commit cf799a6

1 file changed: +209 -9 lines

paddle/fluid/operators/softmax_with_cross_entropy_op.cu
Lines changed: 209 additions & 9 deletions
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@ limitations under the License. */
 
 #define EIGEN_USE_GPU
 
+#include <cub/cub.cuh>
+#include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
 
 namespace paddle {
@@ -53,8 +55,196 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
     logit_grad[ids] = loss_grad[row_ids] * (logit_grad[ids] - labels[ids]);
   }
 }
+
 }  // namespace
 
+static __device__ __forceinline__ float real_exp(float x) { return expf(x); }
+static __device__ __forceinline__ double real_exp(double x) { return exp(x); }
+static __device__ __forceinline__ float real_log(float x) {
+  return math::TolerableValue<float>()(logf(x));
+}
+static __device__ __forceinline__ double real_log(double x) {
+  return math::TolerableValue<double>()(log(x));
+}
+
+/** In the following code, 3 CUDA kernels are implemented to calculate softmax
+ * and loss **/
+/*
+  Supposing x is `logits` and y is `labels`, the equations are as follows:
+
+  cross\_entropy_i = \sum_{j}[- y_i_j * log({e^{x_i_j}/\sum_{j}e^{x_i_j}})]
+        = \sum_{j}[- y_i_j * log({e^{x_i_j - max_i}/\sum_{j}e^{x_i_j-max_i}})]
+        = \sum_{j}[-y_i_j * (x_i_j - max_i - log\sum_{j}e^{x_i_j - max_i})]
+        = \sum_{j}[-y_i_j * (x_i_j - max_i - logDiffMaxSum_i)]
+        = \sum_{j}(-y_i_j * tmp_i_j)
+
+  softmax_i_j = e^{tmp_i_j}
+
+  where:
+    max_i = \max_{j}{x_i_j}
+    logDiffMaxSum_i = log\sum_{j}e^{x_i_j - max_i}
+    tmp_i_j = x_i_j - max_i - logDiffMaxSum_i
+
+  Therefore, the calculation can be separated into 3 steps:
+  Step 1: row-wise operation to calculate max_i
+  Step 2: row-wise operation to calculate logDiffMaxSum_i
+  Step 3: calculate tmp_i_j, and finally get softmax_i_j and cross\_entropy_i
+
+  To save memory, we can share memory among max_i, logDiffMaxSum_i and
+  cross\_entropy_i.
+  In this way, the 3 steps should be changed to:
+  Step 1 (RowReductionForMax): row-wise operation to calculate max_i
+  Step 2 (RowReductionForDiffMaxSum): calculate the intermediate result
+  softmax'_i_j = x_i_j - max_i, and row-wise operation to calculate
+  logDiffMaxSum_i
+  Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate tmp_i_j =
+  softmax'_i_j - logDiffMaxSum_i, and finally get softmax_i_j and
+  cross\_entropy_i
+*/
+
+// There are 3 kinds of reduce algorithms in cub:
+// BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
+// BLOCK_REDUCE_RAKING
+// BLOCK_REDUCE_WARP_REDUCTIONS (default)
+template <typename T, int BlockDim>
+using BlockReduce =
+    cub::BlockReduce<T, BlockDim /*, cub::BLOCK_REDUCE_WARP_REDUCTIONS*/>;
+
+template <typename T, int BlockDim>
+using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
+
+// Make sure that BlockDim <= feature_size
+// This kernel is used to calculate the max element of each row
+template <typename T, int BlockDim>
+__global__ void RowReductionForMax(const T* logits_data, T* max_data,
+                                   int feature_size) {
+  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
+
+  auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
+  auto end_idx = feature_size * (blockIdx.x + 1);
+
+  T cur_max = logits_data[beg_idx];
+  beg_idx += BlockDim;
+  while (beg_idx < end_idx) {
+    if (cur_max < logits_data[beg_idx]) {
+      cur_max = logits_data[beg_idx];
+    }
+    beg_idx += BlockDim;
+  }
+
+  cur_max = BlockReduce<T, BlockDim>(temp_storage).Reduce(cur_max, cub::Max());
+
+  if (threadIdx.x == 0) {
+    max_data[blockIdx.x] = cur_max < -64 ? -64 : cur_max;
+  }
+}
+
+// Make sure that BlockDim <= feature_size
+template <typename T, int BlockDim>
+__global__ void RowReductionForDiffMaxSum(const T* logits_data, T* max_data,
+                                          T* softmax, int feature_size) {
+  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
+
+  auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
+  auto end_idx = feature_size * (blockIdx.x + 1);
+
+  auto block_max = max_data[blockIdx.x];
+
+  softmax[beg_idx] = logits_data[beg_idx] - block_max;
+  T diff_max_sum = real_exp(softmax[beg_idx]);
+  beg_idx += BlockDim;
+  while (beg_idx < end_idx) {
+    softmax[beg_idx] = logits_data[beg_idx] - block_max;
+    diff_max_sum += real_exp(softmax[beg_idx]);
+    beg_idx += BlockDim;
+  }
+
+  diff_max_sum =
+      BlockReduce<T, BlockDim>(temp_storage).Reduce(diff_max_sum, cub::Sum());
+  if (threadIdx.x == 0) max_data[blockIdx.x] = real_log(diff_max_sum);
+}
+
+// Make sure that BlockDim <= feature_size
+template <typename T, int BlockDim>
+__global__ void RowReductionForSoftmaxAndCrossEntropy(const T* logits_data,
+                                                      const T* labels_data,
+                                                      T* loss_data, T* softmax,
+                                                      int feature_size) {
+  __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
+
+  auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
+  auto end_idx = feature_size * (blockIdx.x + 1);
+
+  // log_diff_max_sum shares memory with loss
+  auto block_log_diff_max_sum = loss_data[blockIdx.x];
+  auto tmp = softmax[beg_idx] - block_log_diff_max_sum;
+  softmax[beg_idx] = real_exp(tmp);
+  auto loss = -labels_data[beg_idx] * tmp;
+  beg_idx += BlockDim;
+  while (beg_idx < end_idx) {
+    tmp = softmax[beg_idx] - block_log_diff_max_sum;
+    softmax[beg_idx] = real_exp(tmp);
+    loss -= (labels_data[beg_idx] * tmp);
+    beg_idx += BlockDim;
+  }
+
+  loss = BlockReduce<T, BlockDim>(temp_storage).Reduce(loss, cub::Sum());
+  if (threadIdx.x == 0) loss_data[blockIdx.x] = loss;
+}
+
+template <typename T>
+__global__ void SetSoftmaxToOneWhenFeatureSizeIsOne(T* out, int batch_size) {
+  auto idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < batch_size) out[idx] = static_cast<T>(1);
+}
+
+template <typename T>
+static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
+                                               const T* labels_data,
+                                               T* softmax_data, T* loss_data,
+                                               int batch_size, int feature_size,
+                                               cudaStream_t stream) {
+  constexpr int kMaxBlockDim = 512;
+  int block_dim = feature_size >= kMaxBlockDim
+                      ? kMaxBlockDim
+                      : (1 << static_cast<int>(std::log2(feature_size)));
+
+#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)               \
+  case BlockDim:                                                             \
+    RowReductionForMax<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>(    \
+        logits_data, loss_data, feature_size);                               \
+    RowReductionForDiffMaxSum<T,                                             \
+                              BlockDim><<<batch_size, BlockDim, 0, stream>>>(\
+        logits_data, loss_data, softmax_data, feature_size);                 \
+    RowReductionForSoftmaxAndCrossEntropy<                                   \
+        T, BlockDim><<<batch_size, BlockDim, 0, stream>>>(                   \
+        logits_data, labels_data, loss_data, softmax_data, feature_size);    \
+    break
+
+  switch (block_dim) {
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
+    CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
+    case 1:
+      SetSoftmaxToOneWhenFeatureSizeIsOne<<<(batch_size + kMaxBlockDim - 1) /
+                                                kMaxBlockDim,
+                                            kMaxBlockDim, 0, stream>>>(
+          softmax_data, batch_size);
+      cudaMemsetAsync(loss_data, 0, batch_size, stream);
+      break;
+    default:
+      PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
+      break;
+  }
+
+#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
+}
+
 template <typename T>
 class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
  public:
@@ -66,14 +256,24 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
     Tensor* softmax = context.Output<Tensor>("Softmax");
 
     Tensor* loss = context.Output<Tensor>("Loss");
-    softmax->mutable_data<T>(context.GetPlace());
-    loss->mutable_data<T>(context.GetPlace());
-
-    math::SoftmaxFunctor<platform::CUDADeviceContext, T>()(
-        context.cuda_device_context(), logits, softmax);
-    math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
-        context.cuda_device_context(), loss, softmax, labels,
-        context.Attr<bool>("soft_label"));
+    auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
+    auto* loss_data = loss->mutable_data<T>(context.GetPlace());
+
+    auto soft_label = context.Attr<bool>("soft_label");
+    if (soft_label) {
+      int batch_size = logits->dims()[0];
+      int feature_size = logits->dims()[1];
+      auto* logits_data = logits->data<T>();
+      auto* labels_data = labels->data<T>();
+      SoftmaxWithCrossEntropyFusedKernel(
+          logits_data, labels_data, softmax_data, loss_data, batch_size,
+          feature_size, context.cuda_device_context().stream());
+    } else {
+      math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(), logits,
+                                     softmax);
+      math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
+          context.cuda_device_context(), loss, softmax, labels, false);
+    }
   }
 };
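The three-step scheme in the block comment of the diff (row max, log-sum-exp of the shifted logits, then softmax and loss from the shared intermediate) can be checked against a plain host-side reference. The sketch below is not part of the commit; the function name RowSoftmaxCrossEntropyRef, the use of double precision, and the std::vector interface are assumptions for illustration. It handles a single row of logits x with soft labels y.

#include <cmath>
#include <cstddef>
#include <vector>

// Host-side reference sketch (illustrative, not from the PaddlePaddle source):
//   Step 1: m   = max_j x_j
//   Step 2: lse = log(sum_j exp(x_j - m))
//   Step 3: softmax_j = exp(x_j - m - lse), loss = -sum_j y_j * (x_j - m - lse)
double RowSoftmaxCrossEntropyRef(const std::vector<double>& x,
                                 const std::vector<double>& y,
                                 std::vector<double>* softmax) {
  double m = x[0];
  for (double v : x) m = std::max(m, v);      // Step 1: row-wise max

  double sum = 0.0;
  for (double v : x) sum += std::exp(v - m);  // Step 2: sum of exp(x - max)
  double lse = std::log(sum);                 //         logDiffMaxSum

  double loss = 0.0;
  softmax->resize(x.size());
  for (std::size_t j = 0; j < x.size(); ++j) {  // Step 3: softmax and loss
    double t = x[j] - m - lse;                  // tmp_i_j
    (*softmax)[j] = std::exp(t);
    loss -= y[j] * t;
  }
  return loss;
}

For example, logits {1, 2, 3} with one-hot labels {0, 0, 1} give loss ≈ 0.4076, which matches -log(softmax_3).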

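All three row-reduction kernels in the diff follow the same cub::BlockReduce pattern: one block per row, each thread striding over the row to build a partial result, then a block-wide reduction whose result is written by thread 0. The standalone sketch below shows only that pattern; the kernel name RowMaxSketch and the fixed launch configuration are assumptions for illustration, not code from the commit.

#include <cub/cub.cuh>

// One CUDA block reduces one row of a row-major [rows x cols] matrix to its
// maximum. BlockDim is the number of threads per block and must match the
// launch configuration; it is assumed that BlockDim <= cols.
template <typename T, int BlockDim>
__global__ void RowMaxSketch(const T* data, T* row_max, int cols) {
  using Reduce = cub::BlockReduce<T, BlockDim>;
  __shared__ typename Reduce::TempStorage temp_storage;

  // Each thread strides over its row with step BlockDim, keeping a running max.
  int row_begin = blockIdx.x * cols;
  T thread_max = data[row_begin + threadIdx.x];
  for (int j = threadIdx.x + BlockDim; j < cols; j += BlockDim) {
    if (thread_max < data[row_begin + j]) thread_max = data[row_begin + j];
  }

  // Combine the per-thread partial maxima across the whole block.
  T block_max = Reduce(temp_storage).Reduce(thread_max, cub::Max());
  if (threadIdx.x == 0) row_max[blockIdx.x] = block_max;
}

// Example launch for a rows x cols float matrix with cols >= 128:
//   RowMaxSketch<float, 128><<<rows, 128, 0, stream>>>(d_data, d_row_max, cols);

RowReductionForDiffMaxSum and RowReductionForSoftmaxAndCrossEntropy in the diff replace the running max with a running sum (reduced with cub::Sum()) but keep this same block-per-row structure.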