Skip to content

Commit d7bbd3b

Browse files
committed
fix(kernel): 设置 sharedMemory
Signed-off-by: YdrMaster <[email protected]>
1 parent 7286459 commit d7bbd3b

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/04kernel/src/kernels/attention/cuda_kernel.cu

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@ namespace refactor::kernel {
2626
// gridDim.x = batch * nHead
2727
// gridDim.y = seqLen
2828
// blockDim.x = min(1024, attLen)
29+
// sizeof(shared) = attLen * sizeof(float)
2930
template<class T>
3031
static __global__ void softmax(
3132
T *__restrict__ att,
@@ -154,7 +155,9 @@ namespace refactor::kernel {
154155
workspaceQK, d->workspaceSizeQK,
155156
cudaStreamLegacy);
156157
}
157-
softmax<<<dim3(info.batch * info.nHead, info.seqLen), info.seqLen>>>(
158+
softmax<<<dim3(info.batch * info.nHead, info.seqLen),
159+
info.seqLen,
160+
info.seqLen * sizeof(float)>>>(
158161
att, causualMask, info.seqLen, info.seqLen);
159162
{
160163
half alpha = 1, beta = 0;

0 commit comments

Comments
 (0)