Skip to content

Commit 089513b

Browse files
committed
refactor(kernel): 实现一种不依赖模板参数的 BlockReduce 并用于 softmax
Signed-off-by: YdrMaster <[email protected]>
1 parent 95c991a commit 089513b

File tree

3 files changed

+41
-14
lines changed

3 files changed

+41
-14
lines changed

src/04kernel/cuda/include/kernel/cuda/reduce.cuh

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,28 @@
44
#include <cub/warp/warp_reduce.cuh>
55

66
namespace refactor::kernel::cuda {
7-
}
7+
8+
/// Block-wide reduction whose block size is read from `blockDim.x` at run time
/// instead of being a template parameter.
///
/// Every thread of the block contributes `x`; the combined value is returned
/// to every thread.
///
/// Preconditions:
///  - 1-D thread block with at most 1024 threads (at most 32 warps), because
///    only `threadIdx.x` / `blockDim.x` are consulted and the per-warp scratch
///    arrays hold 32 entries.
///  - Must be called by all threads of the block: it contains `__syncthreads()`.
///  - `init` must be an identity element of `op`; it pads the second-stage
///    reduction whenever the block has fewer than 32 warps.
///
/// @param x    value contributed by the calling thread
/// @param init identity element of `op`
/// @param op   binary, associative reduction functor
/// @return     reduction of all `x` in the block, identical in every thread
template<class T, class ReductionOp>
__inline__ __device__ T blockReduce(T x, T init, ReductionOp op) {
    constexpr int WARP_THREADS = 32;
    using WarpReduce = cub::WarpReduce<T>;
    // CUB asks for one TempStorage per warp; a single shared instance would be
    // used concurrently by every warp of the block.
    __shared__ typename WarpReduce::TempStorage tempStorage[WARP_THREADS];
    __shared__ T shared[WARP_THREADS], ans;

    int const lane = threadIdx.x % WARP_THREADS;
    int const wid = threadIdx.x / WARP_THREADS;
    // Round up so that a trailing partial warp is counted as well.
    int const warps = (static_cast<int>(blockDim.x) + WARP_THREADS - 1) / WARP_THREADS;
    // Number of threads that actually exist in this warp (warp-uniform).
    int const live = static_cast<int>(blockDim.x) - wid * WARP_THREADS;

    // Stage 1: every warp reduces its own values.
    auto reduce = WarpReduce(tempStorage[wid]);
    if (live >= WARP_THREADS) {
        x = reduce.Reduce(x, op);
    } else {
        // Partial warp: only the first `live` lanes hold valid items.
        x = reduce.Reduce(x, op, live);
    }
    if (lane == 0) { shared[wid] = x; }
    __syncthreads();// publish the per-warp partial results

    // Stage 2: warp 0 combines the per-warp partial results.
    if (wid == 0) {
        x = (lane < warps) ? shared[lane] : init;
        if (live >= WARP_THREADS) {
            x = reduce.Reduce(x, op);
        } else {
            // The block is smaller than one warp, so `warps == 1`.
            x = reduce.Reduce(x, op, warps);
        }
        if (lane == 0) { ans = x; }
    }
    __syncthreads();// `ans` must be written before any thread reads it
    return ans;
}
28+
29+
}// namespace refactor::kernel::cuda
830

931
#endif// KERNEL_CUDA_REDUCE_CUH

src/04kernel/src/kernels/softmax/cuda_kernel.cu

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include "cuda_kernel.hh"
2-
#include <cub/cub.cuh>
2+
#include "kernel/cuda/reduce.cuh"
33

44
namespace refactor::kernel {
55
using namespace runtime;
@@ -18,8 +18,8 @@ namespace refactor::kernel {
1818
template<> __device__ __forceinline__ nv_bfloat16 reciprocal<nv_bfloat16>(nv_bfloat16 x) { return hrcp(x); }
1919

2020
// blockDim.x is read at run time (there is no BLOCK_DIM template parameter)
21-
template<int BLOCK_DIM, class T>
22-
__launch_bounds__(BLOCK_DIM) __global__ void blockSoftmaxKernel(
21+
template<class T>
22+
__global__ void blockSoftmaxKernel(
2323
T const *__restrict x,
2424
T *__restrict y,
2525
int mid,
@@ -40,10 +40,8 @@ namespace refactor::kernel {
4040
for (int i = threadIdx.x + blockDim.x; i < mid; i += blockDim.x) {
4141
maxSumThread = MaxSum::reduce(maxSumThread, {x[id + i * stride], 1});// reduce the data to one block
4242
}
43-
using BlockReduce = cub::BlockReduce<MaxSum, BLOCK_DIM>;
44-
__shared__ typename BlockReduce::TempStorage tempStorage;
4543
__shared__ MaxSum maxSumTotal;
46-
auto maxSumBlock = BlockReduce(tempStorage).Reduce(maxSumThread, MaxSum::reduce);
44+
auto maxSumBlock = cuda::blockReduce(maxSumThread, {-__FLT_MAX__, 0}, MaxSum::reduce);
4745
if (threadIdx.x == 0) {
4846
maxSumTotal = maxSumBlock;// must set threadIdx.x = 0 write the output to memory
4947
}
@@ -113,7 +111,7 @@ namespace refactor::kernel {
113111
auto y = reinterpret_cast<T *>(outputs[0]);
114112
int numBlocks = info.pre * info.post;
115113
if (info.mid > 1024) {
116-
blockSoftmaxKernel<1024><<<numBlocks, 1024>>>(x, y, info.mid, info.post);
114+
blockSoftmaxKernel<<<numBlocks, 1024>>>(x, y, info.mid, info.post);
117115
} else {
118116
int blockDimX, mid = static_cast<int>(info.mid);
119117
for (blockDimX = 32; blockDimX > 4 && mid < blockDimX; blockDimX /= 2) {}

src/04kernel/test/kernels/softmax/test_cuda.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@
44
#include "../../../src/kernels/softmax/cuda_kernel.hh"
55
#include "hardware/device_manager.h"
66
#include <gtest/gtest.h>
7+
#include <numeric>
78

89
using namespace refactor;
910
using namespace kernel;
1011
using namespace hardware;
1112

12-
TEST(kernel, SoftmaxCuda) {
13+
static void test(Shape shape, int axis) {
1314
// build routine
14-
auto xTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 5, 4});
15-
auto outTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 5, 4});
16-
dim_t axis = 1;
17-
auto kCpu = SoftmaxCpu::build(SoftmaxInfo(*xTensor, axis));
18-
auto kCuda = SoftmaxCuda::build(SoftmaxInfo(*xTensor, axis));
15+
auto xTensor = Tensor::share(DataType::F32, shape);
16+
auto outTensor = Tensor::share(DataType::F32, shape);
17+
SoftmaxInfo info(*xTensor, axis);
18+
auto kCpu = SoftmaxCpu::build(info);
19+
auto kCuda = SoftmaxCuda::build(info);
1920
ASSERT_TRUE(kCpu && kCuda);
2021
auto res = runtime::Resources();
2122
auto rCpu = kCpu->lower(res).routine;
@@ -28,6 +29,7 @@ TEST(kernel, SoftmaxCuda) {
2829
std::vector<float>
2930
data(xTensor->elementsSize(), 0),
3031
cpuOut(outTensor->elementsSize());
32+
std::iota(data.begin(), data.end(), 0);
3133
gpuIn->copyFromHost(data.data(), xTensor->bytesSize());
3234
// inference
3335
{
@@ -49,4 +51,9 @@ TEST(kernel, SoftmaxCuda) {
4951
}
5052
}
5153

54+
// Runs the shared `test` helper on two input shapes, both with softmax axis 1.
TEST(kernel, SoftmaxCuda) {
    constexpr int axis = 1;
    // Short softmax axis (length 3).
    test(Shape{2, 3, 2, 5, 4}, axis);
    // Long softmax axis (length 2048); presumably exercises the
    // `info.mid > 1024` launch branch -- TODO confirm against SoftmaxInfo.
    test(Shape{2, 2048, 2, 5, 4}, axis);
}
58+
5259
#endif

0 commit comments

Comments
 (0)