Commit 4799826

feat(kernel): start implementing softmax

Signed-off-by: YdrMaster <[email protected]>

1 parent 51b8dfc commit 4799826

File tree

1 file changed: +28 -15 lines changed

src/04kernel/src/kernels/attention/cuda_kernel.cu

Lines changed: 28 additions & 15 deletions
@@ -7,26 +7,39 @@ namespace refactor::kernel {
     using K = AttentionCuda;
     using namespace cublas;
 
-    static __forceinline__ __device__ bool mask(int tokid, int posid) {
-        return true;
+    // Causal attention mask.
+    // tokenId: index of this token
+    // seqLen: number of tokens handled in this pass
+    // posId: position in the kv cache
+    // attLen = pastSeqLen + seqLen
+    static __forceinline__ __device__ bool
+    causualMask(int tokenId, int seqLen,
+                int posId, int attLen) {
+        // tokenId ↓ |<---attLen---->|
+        //   0       | * * ... *     |
+        //   1       | * * ... * *   |
+        //   2       | * * ... * * * |
+        // seqLen: 3 |---------------|
+        return attLen + tokenId >= posId + seqLen;
     }
 
     // gridDim.x = batch * nHead
     // gridDim.y = seqLen
-    template<class T, class Mask>
+    // blockDim.x = min(1024, attLen)
+    template<class T>
     static __global__ void softmax(
-        T *__restrict__ attention,
-        Mask mask,
-        uint32_t seqLen,
+        T *__restrict__ att,
+        bool (*mask)(int, int, int, int),
+        uint32_t attLen,
         uint32_t bufLen) {
-        // int offset = (blockIdx.x * len_q + blockIdx.y) * len_buf;
-        // SharedMemory<float> shared;
-        // float *smem = shared.getPointer();
+        // locate the attention region handled by this thread block
+        att += (blockIdx.x * gridDim.x + gridDim.y) * bufLen;
+        // load the input into shared memory, casting and masking on the way
+        extern __shared__ float shared[];// size = attLen = pastSeqLen + seqLen
+        for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
+            shared[i] = mask(blockIdx.y, gridDim.y, i, attLen) ? float(att[i]) : -__FLT_MAX__;
+        }
 
-        // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
-        //     T pb = (position_bias == nullptr) ? T(0.) : position_bias[offset + i];
-        //     smem[i] = mask[blockIdx.y * len_buf + i] > 0 ? x[offset + i] * scale + pb : -Inf<T>();
-        // }
         // float local_max = -1e20;
         // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
         //     local_max = fmaxf(local_max, smem[i]);
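
The causualMask predicate added above admits a kv-cache position posId for token tokenId exactly when posId <= pastSeqLen + tokenId, which is the staircase drawn in the comment. A minimal host-side sketch of that behaviour (the standalone main() and the concrete pastSeqLen/seqLen values are illustrative assumptions, not part of the commit):

#include <cstdio>

// Host-side mirror of the device predicate, for illustration only.
static bool causalMaskHost(int tokenId, int seqLen, int posId, int attLen) {
    return attLen + tokenId >= posId + seqLen;
}

int main() {
    int pastSeqLen = 4, seqLen = 3, attLen = pastSeqLen + seqLen;
    for (int tokenId = 0; tokenId < seqLen; ++tokenId) {
        for (int posId = 0; posId < attLen; ++posId) {
            std::printf("%c ", causalMaskHost(tokenId, seqLen, posId, attLen) ? '*' : '.');
        }
        std::printf("\n");
    }
    // each row t keeps positions 0..pastSeqLen+t, i.e. the lower
    // staircase from the diagram in the diff
    return 0;
}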
@@ -125,7 +138,7 @@ namespace refactor::kernel {
             auto k = inputs[1];
             auto v = inputs[2];
             auto o = outputs[0];
-            auto att = workspace;
+            auto att = reinterpret_cast<half *>(workspace);
             auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(d->attSize, 256);
             auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256);
 
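
The hunk above reinterprets the front of the workspace as the half-precision attention buffer, then places the cublasLt scratch areas for the two matmuls behind it, each start rounded up to a 256-byte boundary via hardware::alignBytes. A rough standalone sketch of that layout, assuming alignBytes rounds its first argument up to a multiple of the second (alignUp, partition and the struct below are hypothetical helpers, not the repository's API):

#include <cstddef>
#include <cstdint>

// Round size up to the next multiple of alignment (a power of two).
static size_t alignUp(size_t size, size_t alignment) {
    return (size + alignment - 1) & ~(alignment - 1);
}

// Hypothetical view of how the single workspace pointer is partitioned.
struct AttentionWorkspace {
    uint8_t *att;         // attention scores, attSize bytes
    uint8_t *workspaceQK; // cublasLt scratch for the Q x K^T product
    uint8_t *workspaceAV; // cublasLt scratch for the att x V product
};

static AttentionWorkspace partition(uint8_t *workspace, size_t attSize, size_t sizeQK) {
    AttentionWorkspace w;
    w.att = workspace;
    w.workspaceQK = workspace + alignUp(attSize, 256);
    w.workspaceAV = w.workspaceQK + alignUp(sizeQK, 256);
    return w;
}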
@@ -143,7 +156,7 @@ namespace refactor::kernel {
                 cudaStreamLegacy);
 
             softmax<<<dim3(info.batch * info.nHead, info.seqLen), info.seqLen>>>(
-                att, mask, info.seqLen, info.seqLen);
+                att, causualMask, info.seqLen, info.seqLen);
 
             cublasLtMatmul(
                 handle, d->mul.get(),
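
The kernel currently stops after the masked load; the commented-out lines hint at the usual continuation: a block-wide max for numerical stability, exponentiation, a sum, then normalization and a cast back to T. A sketch of that remaining work follows. It is not the commit's code: softmaxSketch and blockReduce are hypothetical names, the row offset is computed as blockIdx.x * gridDim.y + blockIdx.y (one block per head-token pair, as the grid comments suggest), and the reduction is a plain shared-memory tree that works for any blockDim.x up to 1024.

#include <cfloat>
#include <cstdint>
#include <cuda_fp16.h>

// Block-wide tree reduction over one float per thread.
// Valid for any blockDim.x <= 1024; every thread must call it.
template<class Op>
static __device__ float blockReduce(float val, Op op) {
    __shared__ float red[1024];
    red[threadIdx.x] = val;
    __syncthreads();
    for (unsigned active = blockDim.x; active > 1;) {
        unsigned half = (active + 1) / 2;
        if (threadIdx.x + half < active) {
            red[threadIdx.x] = op(red[threadIdx.x], red[threadIdx.x + half]);
        }
        __syncthreads();
        active = half;
    }
    float ans = red[0];
    __syncthreads();// allow red[] to be reused by a later call
    return ans;
}

// gridDim.x = batch * nHead, gridDim.y = seqLen, blockDim.x = min(1024, attLen),
// dynamic shared memory = attLen * sizeof(float), as in the commit's comments.
template<class T>
static __global__ void softmaxSketch(
    T *__restrict__ att,
    bool (*mask)(int, int, int, int),
    uint32_t attLen,
    uint32_t bufLen) {
    // one row of the attention matrix per block
    att += (blockIdx.x * gridDim.y + blockIdx.y) * bufLen;

    // masked load into shared memory, cast to float (as in the committed code)
    extern __shared__ float shared[];
    for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
        shared[i] = mask(blockIdx.y, gridDim.y, i, attLen) ? float(att[i]) : -FLT_MAX;
    }
    __syncthreads();

    // 1) block-wide max for numerical stability
    float localMax = -FLT_MAX;
    for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
        localMax = fmaxf(localMax, shared[i]);
    }
    float max_ = blockReduce(localMax, [](float a, float b) { return fmaxf(a, b); });

    // 2) exponentiate in place and accumulate the denominator
    float localSum = 0;
    for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
        shared[i] = expf(shared[i] - max_);
        localSum += shared[i];
    }
    float sum = blockReduce(localSum, [](float a, float b) { return a + b; });

    // 3) normalize and write back in the original precision
    for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
        att[i] = T(shared[i] / sum);
    }
}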
