|
| 1 | +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include <algorithm> |
| 16 | +#include <cub/cub.cuh> // NOLINT |
| 17 | +#include "paddle/fluid/operators/sequence_softmax_op.h" |
| 18 | + |
| 19 | +namespace paddle { |
| 20 | +namespace operators { |
| 21 | + |
| 22 | +using LoDTensor = framework::LoDTensor; |
| 23 | + |
| 24 | +__device__ __forceinline__ float real_exp(float x) { return expf(x); } |
| 25 | +__device__ __forceinline__ double real_exp(double x) { return exp(x); } |
| 26 | + |
| 27 | +template <typename T, int BlockDim> |
| 28 | +using BlockReduce = cub::BlockReduce<T, BlockDim>; |
| 29 | + |
| 30 | +template <typename T, int BlockDim> |
| 31 | +using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage; |
| 32 | + |
| 33 | +template <typename T, int BlockDim> |
| 34 | +__global__ void sequence_softmax_kernel(const T *in_data, const size_t *ref_lod, |
| 35 | + const size_t src_hight, T *out_data) { |
| 36 | + __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage; |
| 37 | + __shared__ T shared_max_data; |
| 38 | + __shared__ T shared_sum_data; |
| 39 | + |
| 40 | + for (int i = blockIdx.x; i < src_hight; i += gridDim.x) { |
| 41 | + size_t start = ref_lod[i]; |
| 42 | + size_t span = ref_lod[i + 1] - start; |
| 43 | + |
| 44 | + // Find the max ele |
| 45 | + T max_ele = -FLT_MAX; |
| 46 | + for (int tid = threadIdx.x; tid < span; tid += blockDim.x) { |
| 47 | + T ele = in_data[start + tid]; |
| 48 | + max_ele = max_ele > ele ? max_ele : ele; |
| 49 | + } |
| 50 | + max_ele = |
| 51 | + BlockReduce<T, BlockDim>(temp_storage).Reduce(max_ele, cub::Max()); |
| 52 | + if (threadIdx.x == 0) { |
| 53 | + shared_max_data = max_ele; |
| 54 | + } |
| 55 | + __syncthreads(); |
| 56 | + |
| 57 | + // sum |
| 58 | + T sum_data = 0; |
| 59 | + for (int tid = threadIdx.x; tid < span; tid += blockDim.x) { |
| 60 | + T ele = in_data[start + tid]; |
| 61 | + sum_data += real_exp(ele - shared_max_data); |
| 62 | + } |
| 63 | + sum_data = |
| 64 | + BlockReduce<T, BlockDim>(temp_storage).Reduce(sum_data, cub::Sum()); |
| 65 | + if (threadIdx.x == 0) { |
| 66 | + shared_sum_data = sum_data; |
| 67 | + } |
| 68 | + __syncthreads(); |
| 69 | + |
| 70 | + // get final resit |
| 71 | + for (int tid = threadIdx.x; tid < span; tid += blockDim.x) { |
| 72 | + T ele = in_data[start + tid]; |
| 73 | + ele = real_exp(ele - shared_max_data) / shared_sum_data; |
| 74 | + out_data[start + tid] = ele; |
| 75 | + } |
| 76 | + } |
| 77 | +} |
| 78 | + |
| 79 | +template <typename T, int BlockDim> |
| 80 | +__global__ void sequence_softmax_grad_kernel(const T *softmax_grad_data, |
| 81 | + const T *softmax_data, |
| 82 | + const size_t *ref_lod, |
| 83 | + const size_t src_hight, |
| 84 | + T *dx_data) { |
| 85 | + __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage; |
| 86 | + __shared__ T shared_data; |
| 87 | + |
| 88 | + for (int i = blockIdx.x; i < src_hight; i += gridDim.x) { |
| 89 | + size_t start = ref_lod[i]; |
| 90 | + size_t span = ref_lod[i + 1] - start; |
| 91 | + |
| 92 | + T result = 0; |
| 93 | + for (int tid = threadIdx.x; tid < span; tid += blockDim.x) { |
| 94 | + size_t idx = start + tid; |
| 95 | + T s_g_d = softmax_grad_data[idx]; |
| 96 | + T s_d = softmax_data[idx]; |
| 97 | + result += s_g_d * s_d; |
| 98 | + } |
| 99 | + result = BlockReduce<T, BlockDim>(temp_storage).Reduce(result, cub::Sum()); |
| 100 | + if (threadIdx.x == 0) { |
| 101 | + shared_data = result; |
| 102 | + } |
| 103 | + __syncthreads(); |
| 104 | + |
| 105 | + for (int tid = threadIdx.x; tid < span; tid += blockDim.x) { |
| 106 | + size_t idx = start + tid; |
| 107 | + T s_g_d = softmax_grad_data[idx]; |
| 108 | + T s_d = softmax_data[idx]; |
| 109 | + dx_data[idx] = (s_g_d - shared_data) * s_d; |
| 110 | + } |
| 111 | + } |
| 112 | +} |
| 113 | + |
| 114 | +template <typename T> |
| 115 | +struct SequenceSoftmaxFunctor<platform::CUDADeviceContext, T> { |
| 116 | + void operator()(const platform::CUDADeviceContext &context, |
| 117 | + const LoDTensor &x, |
| 118 | + const framework::Vector<size_t> &ref_lod, /*referenced lod*/ |
| 119 | + LoDTensor *out) { |
| 120 | + int hight = ref_lod.size() - 1; |
| 121 | + |
| 122 | + const int kThreadsPerBlock = 32; |
| 123 | + int thread_x = kThreadsPerBlock; |
| 124 | + int max_threads = context.GetMaxPhysicalThreadCount(); |
| 125 | + int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); |
| 126 | + |
| 127 | + dim3 block_size(thread_x); |
| 128 | + dim3 grid_size(max_blocks); |
| 129 | + sequence_softmax_kernel< |
| 130 | + T, kThreadsPerBlock><<<grid_size, block_size, 0, context.stream()>>>( |
| 131 | + x.data<T>(), ref_lod.CUDAData(context.GetPlace()), hight, |
| 132 | + out->mutable_data<T>(context.GetPlace())); |
| 133 | + } |
| 134 | +}; |
| 135 | + |
| 136 | +template <typename T> |
| 137 | +struct SequenceSoftmaxGradFunctor<platform::CUDADeviceContext, T> { |
| 138 | + void operator()(const platform::CUDADeviceContext &context, |
| 139 | + const LoDTensor &dout, const LoDTensor &out, |
| 140 | + const framework::Vector<size_t> &ref_lod, /*referenced lod*/ |
| 141 | + LoDTensor *dx) { |
| 142 | + size_t hight = ref_lod.size() - 1; |
| 143 | + |
| 144 | + const int kThreadsPerBlock = 32; |
| 145 | + int thread_x = kThreadsPerBlock; |
| 146 | + int max_threads = context.GetMaxPhysicalThreadCount(); |
| 147 | + int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); |
| 148 | + |
| 149 | + dim3 block_size(thread_x); |
| 150 | + dim3 grid_size(max_blocks); |
| 151 | + |
| 152 | + sequence_softmax_grad_kernel< |
| 153 | + T, kThreadsPerBlock><<<grid_size, block_size, 0, context.stream()>>>( |
| 154 | + dout.data<T>(), out.data<T>(), ref_lod.CUDAData(context.GetPlace()), |
| 155 | + hight, dx->mutable_data<T>(context.GetPlace())); |
| 156 | + } |
| 157 | +}; |
| 158 | + |
| 159 | +} // namespace operators |
| 160 | +} // namespace paddle |
| 161 | + |
| 162 | +namespace ops = paddle::operators; |
| 163 | +REGISTER_OP_CUDA_KERNEL( |
| 164 | + sequence_softmax, |
| 165 | + ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, float>, |
| 166 | + ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, double>); |
| 167 | +REGISTER_OP_CUDA_KERNEL( |
| 168 | + sequence_softmax_grad, |
| 169 | + ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, float>, |
| 170 | + ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, |
| 171 | + double>); |
0 commit comments