
Commit 20a34ac

feat(kernel): implement simple attention without kv cache
Signed-off-by: YdrMaster <[email protected]>
1 parent ccf293f commit 20a34ac

1 file changed: +43 −41 lines changed
src/04kernel/src/kernels/attention/cuda_kernel.cu

Lines changed: 43 additions & 41 deletions
@@ -40,22 +40,21 @@ namespace refactor::kernel {
             shared[i] = mask(blockIdx.y, gridDim.y, i, attLen) ? float(att[i]) : -__FLT_MAX__;
         }
 
-        // float local_max = -1e20;
-        // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
-        //     local_max = fmaxf(local_max, smem[i]);
-        // }
-        // local_max = functions::blockReduceMax<float>(local_max);
+        float localMax = -1e20;
+        for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
+            localMax = cub::Max()(localMax, shared[i]);
+        }
+        localMax = cuda::blockReduce(localMax, -1e20f, cub::Max());
 
-        // float local_sum = 1e-20;
-        // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
-        //     float v = expf(float(smem[i]) - local_max);
-        //     smem[i] = v;
-        //     local_sum += v;
-        // }
-        // local_sum = functions::blockReduceSum<float>(local_sum);
-        // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
-        //     x[offset + i] = float(smem[i]) / local_sum;
-        // }
+        float localSum = 1e-20;
+        for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
+            localSum += shared[i] = expf(shared[i] - localMax);
+        }
+        localSum = cuda::blockReduce(localSum, 1e-20f, cub::Sum());
+        auto reciprocal = fdividef(1, localSum);
+        for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
+            att[i] = shared[i] * reciprocal;
+        }
     }
 
     RoutineWorkspace K::lower(Resources &res) const {
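
Note on the hunk above: the rewritten kernel body is the standard numerically stable softmax, one attention row per block: take the block-wide maximum of the masked scores, subtract it before exponentiating, accumulate the block-wide sum, then multiply by the reciprocal. The following is a minimal, self-contained sketch of the same block-per-row pattern written directly against CUB's BlockReduce rather than this repo's cuda::blockReduce helper; the kernel name rowSoftmax, the ROW_THREADS parameter, and the omission of the mask are illustrative assumptions, not code from this commit.

    // Sketch only: same max / exp-sum / normalize pattern as the kernel above,
    // but using cub::BlockReduce directly. One block handles one row of `att`.
    #include <cub/cub.cuh>
    #include <cfloat>

    template<int ROW_THREADS>
    __global__ void rowSoftmax(float *att, int attLen) {
        using BlockReduce = cub::BlockReduce<float, ROW_THREADS>;
        __shared__ typename BlockReduce::TempStorage temp;
        __shared__ float broadcast;

        float *row = att + blockIdx.x * attLen;// one row per block

        // 1. block-wide max for numerical stability
        float localMax = -FLT_MAX;
        for (int i = threadIdx.x; i < attLen; i += ROW_THREADS) {
            localMax = fmaxf(localMax, row[i]);
        }
        float rowMax = BlockReduce(temp).Reduce(localMax, cub::Max());
        if (threadIdx.x == 0) { broadcast = rowMax; }// reduce result is valid only in thread 0
        __syncthreads();
        rowMax = broadcast;

        // 2. exponentiate in place and accumulate the block-wide sum
        float localSum = 0;
        for (int i = threadIdx.x; i < attLen; i += ROW_THREADS) {
            localSum += row[i] = expf(row[i] - rowMax);
        }
        __syncthreads();// barrier before reusing the shared temp storage
        float rowSum = BlockReduce(temp).Sum(localSum);
        if (threadIdx.x == 0) { broadcast = rowSum; }
        __syncthreads();

        // 3. normalize by the reciprocal of the sum
        float reciprocal = fdividef(1.f, broadcast);
        for (int i = threadIdx.x; i < attLen; i += ROW_THREADS) {
            row[i] *= reciprocal;
        }
    }

    // example launch: one block per row, e.g. rowSoftmax<256><<<numRows, 256>>>(att, attLen);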
@@ -141,35 +140,38 @@ namespace refactor::kernel {
             auto att = reinterpret_cast<half *>(workspace);
             auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(d->attSize, 256);
             auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256);
-
-            float alpha = 1, beta = 0;
-            cublasLtMatmul(
-                handle, d->mul.get(),
-                &alpha,
-                q, d->q.get(),
-                k, d->k.get(),
-                &beta,
-                att, d->att.get(),
-                att, d->att.get(),
-                &d->algoQK,
-                workspaceQK, d->workspaceSizeQK,
-                cudaStreamLegacy);
-
+            {
+                half alpha = rsqrtf(info.headDim), beta = 0;
+                cublasLtMatmul(
+                    handle, d->mul.get(),
+                    &alpha,
+                    q, d->q.get(),
+                    k, d->k.get(),
+                    &beta,
+                    att, d->att.get(),
+                    att, d->att.get(),
+                    &d->algoQK,
+                    workspaceQK, d->workspaceSizeQK,
+                    cudaStreamLegacy);
+            }
             softmax<<<dim3(info.batch * info.nHead, info.seqLen), info.seqLen>>>(
                 att, causualMask, info.seqLen, info.seqLen);
-
-            cublasLtMatmul(
-                handle, d->mul.get(),
-                &alpha,
-                att, d->att.get(),
-                v, d->v.get(),
-                &beta,
-                o, d->q.get(),
-                o, d->q.get(),
-                &d->algoAV,
-                workspaceAV, d->workspaceSizeAV,
-                cudaStreamLegacy);
+            {
+                half alpha = 1, beta = 0;
+                cublasLtMatmul(
+                    handle, d->mul.get(),
+                    &alpha,
+                    att, d->att.get(),
+                    v, d->v.get(),
+                    &beta,
+                    o, d->q.get(),
+                    o, d->q.get(),
+                    &d->algoAV,
+                    workspaceAV, d->workspaceSizeAV,
+                    cudaStreamLegacy);
+            };
         };
+
         return {std::move(routine), workspaceSize};
     }
 }
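
Note on the second hunk: the routine now folds the 1/sqrt(headDim) scale into the Q·Kᵀ matmul by passing alpha = rsqrtf(info.headDim) to the first cublasLtMatmul, runs the softmax kernel over the masked scores, then multiplies by V with alpha = 1. For orientation, here is a naive CPU sketch of what that pipeline computes for one (batch, head) slice, assuming causualMask is the usual lower-triangular causal mask and a row-major [seqLen, headDim] layout; the function name attentionRef and that layout are assumptions for illustration, not part of the commit.

    // Reference only: O = softmax(Q · Kᵀ / sqrt(headDim), causal) · V, per head.
    #include <algorithm>
    #include <cmath>
    #include <vector>

    void attentionRef(const float *q, const float *k, const float *v,
                      float *o, int seqLen, int headDim) {
        float scale = 1.f / std::sqrt(float(headDim));
        std::vector<float> att(seqLen);
        for (int i = 0; i < seqLen; ++i) {
            // scaled scores of query i against keys 0..i (causal mask hides j > i)
            float rowMax = -1e30f;
            for (int j = 0; j <= i; ++j) {
                float s = 0;
                for (int d = 0; d < headDim; ++d) { s += q[i * headDim + d] * k[j * headDim + d]; }
                att[j] = s * scale;
                rowMax = std::max(rowMax, att[j]);
            }
            // numerically stable softmax over the unmasked prefix
            float sum = 0;
            for (int j = 0; j <= i; ++j) { sum += att[j] = std::exp(att[j] - rowMax); }
            for (int j = 0; j <= i; ++j) { att[j] /= sum; }
            // O[i] = att · V
            for (int d = 0; d < headDim; ++d) {
                float acc = 0;
                for (int j = 0; j <= i; ++j) { acc += att[j] * v[j * headDim + d]; }
                o[i * headDim + d] = acc;
            }
        }
    }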
