We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 7286459 commit d7bbd3bCopy full SHA for d7bbd3b
src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -26,6 +26,7 @@ namespace refactor::kernel {
26
// gridDim.x = batch * nHead
27
// gridDim.y = seqLen
28
// blockDim.x = min(1024, attLen)
29
+ // sizeof(shared) = attLen * sizeof(float)
30
template<class T>
31
static __global__ void softmax(
32
T *__restrict__ att,
@@ -154,7 +155,9 @@ namespace refactor::kernel {
154
155
workspaceQK, d->workspaceSizeQK,
156
cudaStreamLegacy);
157
}
- softmax<<<dim3(info.batch * info.nHead, info.seqLen), info.seqLen>>>(
158
+ softmax<<<dim3(info.batch * info.nHead, info.seqLen),
159
+ info.seqLen,
160
+ info.seqLen * sizeof(float)>>>(
161
att, causualMask, info.seqLen, info.seqLen);
162
{
163
half alpha = 1, beta = 0;
0 commit comments