Skip to content

Commit 6757a31

Browse files
author
chengduo
authored
[Accelerate] Refine seq_softmax_op (#13421)
* refine seq_softmax_op * fix seq_softmax * use cub in seq_softmax
1 parent c686595 commit 6757a31

File tree

7 files changed

+261
-75
lines changed

7 files changed

+261
-75
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_
107107
paddle.fluid.layers.conv2d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'use_mkldnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, False, None, None))
108108
paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'use_mkldnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, False, None, None))
109109
paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type'], varargs=None, keywords=None, defaults=None)
110-
paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn'], varargs=None, keywords=None, defaults=(None, None, True))
110+
paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn'], varargs=None, keywords=None, defaults=(None, None, False))
111111
paddle.fluid.layers.softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(None, None, True, None))
112112
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'use_mkldnn', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, False, None))
113113
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'use_mkldnn', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, False, None))

paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,12 +252,12 @@ endif()
252252
op_library(cross_entropy_op DEPS cross_entropy)
253253
if(WITH_GPU)
254254
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax cub)
255+
op_library(sequence_softmax_op DEPS cub)
255256
else()
256257
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
257258
endif()
258259

259260
op_library(softmax_op DEPS softmax)
260-
op_library(sequence_softmax_op DEPS softmax)
261261
if (WITH_GPU AND TENSORRT_FOUND)
262262
op_library(tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter)
263263
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(tensorrt_engine);\n")

paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ class SequenceSoftmaxCUDNNKernel : public framework::OpKernel<T> {
2929
auto* x = ctx.Input<LoDTensor>("X");
3030
auto* out = ctx.Output<LoDTensor>("Out");
3131

32-
auto lod = x->lod();
33-
auto dims = x->dims();
32+
auto& lod = x->lod();
33+
auto& dims = x->dims();
3434

3535
const size_t level = lod.size() - 1;
3636
PADDLE_ENFORCE_EQ(dims[0], static_cast<int64_t>(lod[level].back()),
@@ -71,7 +71,7 @@ class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
7171
if (x_grad) {
7272
x_grad->set_lod(x->lod());
7373
}
74-
auto lod = x->lod();
74+
auto& lod = x->lod();
7575
const size_t level = lod.size() - 1;
7676

7777
x_grad->mutable_data<T>(ctx.GetPlace());
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <algorithm>
16+
#include <cub/cub.cuh> // NOLINT
17+
#include "paddle/fluid/operators/sequence_softmax_op.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
using LoDTensor = framework::LoDTensor;
23+
24+
__device__ __forceinline__ float real_exp(float x) { return expf(x); }
25+
__device__ __forceinline__ double real_exp(double x) { return exp(x); }
26+
27+
template <typename T, int BlockDim>
28+
using BlockReduce = cub::BlockReduce<T, BlockDim>;
29+
30+
template <typename T, int BlockDim>
31+
using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
32+
33+
template <typename T, int BlockDim>
34+
__global__ void sequence_softmax_kernel(const T *in_data, const size_t *ref_lod,
35+
const size_t src_hight, T *out_data) {
36+
__shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
37+
__shared__ T shared_max_data;
38+
__shared__ T shared_sum_data;
39+
40+
for (int i = blockIdx.x; i < src_hight; i += gridDim.x) {
41+
size_t start = ref_lod[i];
42+
size_t span = ref_lod[i + 1] - start;
43+
44+
// Find the max ele
45+
T max_ele = -FLT_MAX;
46+
for (int tid = threadIdx.x; tid < span; tid += blockDim.x) {
47+
T ele = in_data[start + tid];
48+
max_ele = max_ele > ele ? max_ele : ele;
49+
}
50+
max_ele =
51+
BlockReduce<T, BlockDim>(temp_storage).Reduce(max_ele, cub::Max());
52+
if (threadIdx.x == 0) {
53+
shared_max_data = max_ele;
54+
}
55+
__syncthreads();
56+
57+
// sum
58+
T sum_data = 0;
59+
for (int tid = threadIdx.x; tid < span; tid += blockDim.x) {
60+
T ele = in_data[start + tid];
61+
sum_data += real_exp(ele - shared_max_data);
62+
}
63+
sum_data =
64+
BlockReduce<T, BlockDim>(temp_storage).Reduce(sum_data, cub::Sum());
65+
if (threadIdx.x == 0) {
66+
shared_sum_data = sum_data;
67+
}
68+
__syncthreads();
69+
70+
// get final resit
71+
for (int tid = threadIdx.x; tid < span; tid += blockDim.x) {
72+
T ele = in_data[start + tid];
73+
ele = real_exp(ele - shared_max_data) / shared_sum_data;
74+
out_data[start + tid] = ele;
75+
}
76+
}
77+
}
78+
79+
template <typename T, int BlockDim>
80+
__global__ void sequence_softmax_grad_kernel(const T *softmax_grad_data,
81+
const T *softmax_data,
82+
const size_t *ref_lod,
83+
const size_t src_hight,
84+
T *dx_data) {
85+
__shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
86+
__shared__ T shared_data;
87+
88+
for (int i = blockIdx.x; i < src_hight; i += gridDim.x) {
89+
size_t start = ref_lod[i];
90+
size_t span = ref_lod[i + 1] - start;
91+
92+
T result = 0;
93+
for (int tid = threadIdx.x; tid < span; tid += blockDim.x) {
94+
size_t idx = start + tid;
95+
T s_g_d = softmax_grad_data[idx];
96+
T s_d = softmax_data[idx];
97+
result += s_g_d * s_d;
98+
}
99+
result = BlockReduce<T, BlockDim>(temp_storage).Reduce(result, cub::Sum());
100+
if (threadIdx.x == 0) {
101+
shared_data = result;
102+
}
103+
__syncthreads();
104+
105+
for (int tid = threadIdx.x; tid < span; tid += blockDim.x) {
106+
size_t idx = start + tid;
107+
T s_g_d = softmax_grad_data[idx];
108+
T s_d = softmax_data[idx];
109+
dx_data[idx] = (s_g_d - shared_data) * s_d;
110+
}
111+
}
112+
}
113+
114+
template <typename T>
115+
struct SequenceSoftmaxFunctor<platform::CUDADeviceContext, T> {
116+
void operator()(const platform::CUDADeviceContext &context,
117+
const LoDTensor &x,
118+
const framework::Vector<size_t> &ref_lod, /*referenced lod*/
119+
LoDTensor *out) {
120+
int hight = ref_lod.size() - 1;
121+
122+
const int kThreadsPerBlock = 32;
123+
int thread_x = kThreadsPerBlock;
124+
int max_threads = context.GetMaxPhysicalThreadCount();
125+
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
126+
127+
dim3 block_size(thread_x);
128+
dim3 grid_size(max_blocks);
129+
sequence_softmax_kernel<
130+
T, kThreadsPerBlock><<<grid_size, block_size, 0, context.stream()>>>(
131+
x.data<T>(), ref_lod.CUDAData(context.GetPlace()), hight,
132+
out->mutable_data<T>(context.GetPlace()));
133+
}
134+
};
135+
136+
template <typename T>
137+
struct SequenceSoftmaxGradFunctor<platform::CUDADeviceContext, T> {
138+
void operator()(const platform::CUDADeviceContext &context,
139+
const LoDTensor &dout, const LoDTensor &out,
140+
const framework::Vector<size_t> &ref_lod, /*referenced lod*/
141+
LoDTensor *dx) {
142+
size_t hight = ref_lod.size() - 1;
143+
144+
const int kThreadsPerBlock = 32;
145+
int thread_x = kThreadsPerBlock;
146+
int max_threads = context.GetMaxPhysicalThreadCount();
147+
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
148+
149+
dim3 block_size(thread_x);
150+
dim3 grid_size(max_blocks);
151+
152+
sequence_softmax_grad_kernel<
153+
T, kThreadsPerBlock><<<grid_size, block_size, 0, context.stream()>>>(
154+
dout.data<T>(), out.data<T>(), ref_lod.CUDAData(context.GetPlace()),
155+
hight, dx->mutable_data<T>(context.GetPlace()));
156+
}
157+
};
158+
159+
} // namespace operators
160+
} // namespace paddle
161+
162+
namespace ops = paddle::operators;
163+
REGISTER_OP_CUDA_KERNEL(
164+
sequence_softmax,
165+
ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, float>,
166+
ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, double>);
167+
REGISTER_OP_CUDA_KERNEL(
168+
sequence_softmax_grad,
169+
ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, float>,
170+
ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext,
171+
double>);

paddle/fluid/operators/sequence_softmax_op.cu.cc

Lines changed: 0 additions & 26 deletions
This file was deleted.

0 commit comments

Comments
 (0)