
Commit 51b8dfc

feat(kernel): confirm the softmax interface to be called by attention

Signed-off-by: YdrMaster <[email protected]>
1 parent 089513b commit 51b8dfc
File tree

3 files changed: +59 −5 lines changed

src/04kernel/src/kernels/attention/cuda_kernel.cu

Lines changed: 43 additions & 5 deletions
@@ -7,6 +7,44 @@ namespace refactor::kernel {
     using K = AttentionCuda;
     using namespace cublas;

+    static __forceinline__ __device__ bool mask(int tokid, int posid) {
+        return true;
+    }
+
+    // gridDim.x = batch * nHead
+    // gridDim.y = seqLen
+    template<class T, class Mask>
+    static __global__ void softmax(
+        T *__restrict__ attention,
+        Mask mask,
+        uint32_t seqLen,
+        uint32_t bufLen) {
+        // int offset = (blockIdx.x * len_q + blockIdx.y) * len_buf;
+        // SharedMemory<float> shared;
+        // float *smem = shared.getPointer();
+
+        // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
+        //     T pb = (position_bias == nullptr) ? T(0.) : position_bias[offset + i];
+        //     smem[i] = mask[blockIdx.y * len_buf + i] > 0 ? x[offset + i] * scale + pb : -Inf<T>();
+        // }
+        // float local_max = -1e20;
+        // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
+        //     local_max = fmaxf(local_max, smem[i]);
+        // }
+        // local_max = functions::blockReduceMax<float>(local_max);
+
+        // float local_sum = 1e-20;
+        // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
+        //     float v = expf(float(smem[i]) - local_max);
+        //     smem[i] = v;
+        //     local_sum += v;
+        // }
+        // local_sum = functions::blockReduceSum<float>(local_sum);
+        // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
+        //     x[offset + i] = float(smem[i]) / local_sum;
+        // }
+    }
+
     RoutineWorkspace K::lower(Resources &res) const {
         auto handle = res.fetchOrStore<CublasLtContext>()->handle;

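The kernel body above is still the commented-out reference implementation, which mentions names (x, scale, position_bias, len_q, len_buf) that are not part of the declared interface. As a point of orientation only, here is a minimal sketch of what the body might become once adapted to the declared parameters; functions::blockReduceMax and functions::blockReduceSum are assumed to be the repo's block-reduction helpers that the comments reference, and the launch would need bufLen * sizeof(float) of dynamic shared memory:

// A hedged sketch, not this commit's code: block (batch*head, tok)
// normalizes one row of bufLen attention scores in place.
template<class T, class Mask>
static __global__ void softmaxSketch(
    T *__restrict__ attention,
    Mask mask,
    uint32_t seqLen,
    uint32_t bufLen) {
    extern __shared__ float smem[];// bufLen floats, sized at launch time
    auto att = attention + (blockIdx.x * seqLen + blockIdx.y) * bufLen;

    // Load one row, writing a large negative value where the mask rejects.
    for (uint32_t i = threadIdx.x; i < bufLen; i += blockDim.x) {
        smem[i] = mask(blockIdx.y, i) ? float(att[i]) : -1e20f;
    }
    // Row max for numerical stability.
    float max_ = -1e20f;
    for (uint32_t i = threadIdx.x; i < bufLen; i += blockDim.x) {
        max_ = fmaxf(max_, smem[i]);
    }
    max_ = functions::blockReduceMax<float>(max_);// assumed helper
    // Exponentiate and accumulate the row sum.
    float sum = 1e-20f;
    for (uint32_t i = threadIdx.x; i < bufLen; i += blockDim.x) {
        float v = expf(smem[i] - max_);
        smem[i] = v;
        sum += v;
    }
    sum = functions::blockReduceSum<float>(sum);// assumed helper
    // Normalize in place.
    for (uint32_t i = threadIdx.x; i < bufLen; i += blockDim.x) {
        att[i] = T(smem[i] / sum);
    }
}

Each thread strides over the row, so blockDim.x only needs to cover bufLen cooperatively rather than match it exactly.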
@@ -23,9 +61,9 @@ namespace refactor::kernel {
         size_t attSize, workspaceSizeQK, workspaceSizeAV;

         Descriptors(CublasLtContext const &context,
-                    cublasComputeType_t compute,
                     AttentionInfo info)
-            : mul(compute, CUDA_R_32F),
+            : mul(computeTypeConvert(info.dataType),
+                  dataTypeConvert(info.dataType)),
              q(MatrixLayout{
                  .dataType = dataTypeConvert(info.dataType),
                  .rows = static_cast<uint64_t>(info.seqLen),
@@ -73,11 +111,10 @@ namespace refactor::kernel {
        };

        auto const &context = *res.fetchOrStore<CublasLtContext>();
-       auto d = std::make_shared<Descriptors>(context, CUBLAS_COMPUTE_32F, info);
+       auto d = std::make_shared<Descriptors>(context, info);
        auto workspaceSize = d->attSize;
        workspaceSize = hardware::alignBytes(workspaceSize, 256);
        workspaceSize += d->workspaceSizeQK;
-       workspaceSize = hardware::alignBytes(workspaceSize, 256);
        workspaceSize += d->workspaceSizeAV;
        workspaceSize = hardware::alignBytes(workspaceSize, 256);

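The workspace arithmetic also changed: the 256-byte re-alignment between the QK and AV scratch regions was dropped, so the two now sit back to back. hardware::alignBytes itself is not shown in this diff; its assumed semantics are the usual round-up-to-a-multiple:

// Assumed semantics of hardware::alignBytes (the real definition lives in
// the hardware module): round size up to the next multiple of a
// power-of-two alignment, e.g. alignBytes(1000, 256) == 1024.
constexpr size_t alignBytes(size_t size, size_t align) {
    return (size + align - 1) & ~(align - 1);
}
// Resulting layout after this commit (sketch):
//   [ att (attSize) | pad to 256 | workspaceQK | workspaceAV | pad to 256 ]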
@@ -105,7 +142,8 @@ namespace refactor::kernel {
                workspaceQK, d->workspaceSizeQK,
                cudaStreamLegacy);

-           // TODO inline mask && softmax
+           softmax<<<dim3(info.batch * info.nHead, info.seqLen), info.seqLen>>>(
+               att, mask, info.seqLen, info.seqLen);

            cublasLtMatmul(
                handle, d->mul.get(),

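Note that mask here is still the placeholder defined at the top of the file, and the launch passes info.seqLen for both seqLen and bufLen, so for now the whole seqLen × seqLen score matrix is normalized unmasked. A causal decoder mask, which the (tokid, posid) signature appears shaped for, would be the obvious next step; a hypothetical variant:

// Hypothetical causal mask (not in this commit): a token may attend only
// to itself and to earlier positions.
static __forceinline__ __device__ bool causalMask(int tokid, int posid) {
    return posid <= tokid;
}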
src/04kernel/src/utilities/cuda/cublaslt_utils.cu

Lines changed: 15 additions & 0 deletions
@@ -34,6 +34,21 @@ namespace refactor::kernel::cublas {
         switch (dt) {
             case DataType::F32:
                 return CUDA_R_32F;
+            case DataType::FP16:
+                return CUDA_R_16F;
+            case DataType::BF16:
+                return CUDA_R_16BF;
+            default:
+                TODO("");
+        }
+    }
+    cublasComputeType_t computeTypeConvert(DataType dt) {
+        switch (dt) {
+            case DataType::F32:
+            case DataType::BF16:
+                return CUBLAS_COMPUTE_32F;
+            case DataType::FP16:
+                return CUBLAS_COMPUTE_16F;
             default:
                 TODO("");
         }
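The two converters are meant to be used as a pair, as the Descriptors change in the first file shows: computeTypeConvert picks the accumulator type and dataTypeConvert the element type. BF16 maps to CUBLAS_COMPUTE_32F because cuBLAS offers no 16-bit bfloat compute type, so bf16 matmuls accumulate in fp32. A hedged usage sketch, with the argument order taken from the Descriptors constructor above:

// Accumulator and element type are derived from the same tensor dtype.
auto compute = computeTypeConvert(DataType::BF16);// -> CUBLAS_COMPUTE_32F
auto element = dataTypeConvert(DataType::BF16);   // -> CUDA_R_16BF
MatMulDescriptor mul(compute, element);           // as in the Descriptors ctor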

src/04kernel/src/utilities/cuda/cublaslt_utils.cuh

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ namespace refactor::kernel::cublas {
    };

    cudaDataType dataTypeConvert(DataType);
+   cublasComputeType_t computeTypeConvert(DataType);

    class MatMulDescriptor {
        cublasLtMatmulDesc_t _internal;
