@@ -40,22 +40,21 @@ namespace refactor::kernel {
             shared[i] = mask(blockIdx.y, gridDim.y, i, attLen) ? float(att[i]) : -__FLT_MAX__;
         }

-        // float local_max = -1e20;
-        // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
-        //     local_max = fmaxf(local_max, smem[i]);
-        // }
-        // local_max = functions::blockReduceMax<float>(local_max);
+        float localMax = -1e20;
+        for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
+            localMax = cub::Max()(localMax, shared[i]);
+        }
+        localMax = cuda::blockReduce(localMax, -1e20f, cub::Max());

-        // float local_sum = 1e-20;
-        // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
-        //     float v = expf(float(smem[i]) - local_max);
-        //     smem[i] = v;
-        //     local_sum += v;
-        // }
-        // local_sum = functions::blockReduceSum<float>(local_sum);
-        // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
-        //     x[offset + i] = float(smem[i]) / local_sum;
-        // }
+        float localSum = 1e-20;
+        for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
+            localSum += shared[i] = expf(shared[i] - localMax);
+        }
+        localSum = cuda::blockReduce(localSum, 1e-20f, cub::Sum());
+        auto reciprocal = fdividef(1, localSum);
+        for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
+            att[i] = shared[i] * reciprocal;
+        }
     }

     RoutineWorkspace K::lower(Resources &res) const {
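The hunk above replaces the commented-out prototype with a working block-wide softmax: each block finds the row maximum, exponentiates the shifted scores in shared memory, block-reduces the sum, and scales by the reciprocal. `cuda::blockReduce` and the `mask` functor are helpers defined elsewhere in this repository and are not shown here. For reference only, the following is a minimal, self-contained sketch of the same max-subtract / exp / sum / normalize pattern written directly against `cub::BlockReduce`, on plain float rows with no masking; the kernel name, the `BLOCK` template parameter, and the broadcast through shared scalars are choices of this sketch, not the repository's implementation.

// Standalone sketch (not the repository's kernel): one thread block normalizes
// one row of `att` of length attLen.
// Launch sketch: softmaxRowSketch<256><<<numRows, 256>>>(att, attLen);
#include <cub/cub.cuh>
#include <cfloat>

template<int BLOCK>
__global__ void softmaxRowSketch(float *att, unsigned attLen) {
    using Reduce = cub::BlockReduce<float, BLOCK>;
    __shared__ typename Reduce::TempStorage temp;
    __shared__ float rowMax, rowSum;
    float *row = att + blockIdx.x * attLen;

    // 1. block-wide maximum, for numerical stability
    float localMax = -FLT_MAX;
    for (unsigned i = threadIdx.x; i < attLen; i += BLOCK) {
        localMax = cub::Max()(localMax, row[i]);
    }
    // only thread 0 holds the block aggregate, so broadcast it via shared memory
    float max_ = Reduce(temp).Reduce(localMax, cub::Max());
    if (threadIdx.x == 0) { rowMax = max_; }
    __syncthreads();

    // 2. exponentiate in place and accumulate the denominator
    float localSum = 1e-20f;// tiny bias guards against division by zero
    for (unsigned i = threadIdx.x; i < attLen; i += BLOCK) {
        localSum += row[i] = expf(row[i] - rowMax);
    }
    float sum = Reduce(temp).Sum(localSum);
    if (threadIdx.x == 0) { rowSum = sum; }
    __syncthreads();

    // 3. normalize by the reciprocal of the row sum
    float reciprocal = fdividef(1.f, rowSum);
    for (unsigned i = threadIdx.x; i < attLen; i += BLOCK) {
        row[i] *= reciprocal;
    }
}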
@@ -141,35 +140,38 @@ namespace refactor::kernel {
             auto att = reinterpret_cast<half *>(workspace);
             auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(d->attSize, 256);
             auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256);
-
-            float alpha = 1, beta = 0;
-            cublasLtMatmul(
-                handle, d->mul.get(),
-                &alpha,
-                q, d->q.get(),
-                k, d->k.get(),
-                &beta,
-                att, d->att.get(),
-                att, d->att.get(),
-                &d->algoQK,
-                workspaceQK, d->workspaceSizeQK,
-                cudaStreamLegacy);
-
+            {
+                half alpha = rsqrtf(info.headDim), beta = 0;
+                cublasLtMatmul(
+                    handle, d->mul.get(),
+                    &alpha,
+                    q, d->q.get(),
+                    k, d->k.get(),
+                    &beta,
+                    att, d->att.get(),
+                    att, d->att.get(),
+                    &d->algoQK,
+                    workspaceQK, d->workspaceSizeQK,
+                    cudaStreamLegacy);
+            }
             softmax<<<dim3(info.batch * info.nHead, info.seqLen), info.seqLen>>>(
                 att, causualMask, info.seqLen, info.seqLen);
-
-            cublasLtMatmul(
-                handle, d->mul.get(),
-                &alpha,
-                att, d->att.get(),
-                v, d->v.get(),
-                &beta,
-                o, d->q.get(),
-                o, d->q.get(),
-                &d->algoAV,
-                workspaceAV, d->workspaceSizeAV,
-                cudaStreamLegacy);
+            {
+                half alpha = 1, beta = 0;
+                cublasLtMatmul(
+                    handle, d->mul.get(),
+                    &alpha,
+                    att, d->att.get(),
+                    v, d->v.get(),
+                    &beta,
+                    o, d->q.get(),
+                    o, d->q.get(),
+                    &d->algoAV,
+                    workspaceAV, d->workspaceSizeAV,
+                    cudaStreamLegacy);
+            };
         };
+
         return {std::move(routine), workspaceSize};
     }
 }
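In the second hunk each `cublasLtMatmul` call gets its own scope with its own `alpha`/`beta`, now typed `half` instead of `float`, and the Q·Kᵀ call uses `alpha = rsqrtf(info.headDim)`, folding the 1/√headDim attention scaling into the GEMM rather than applying it in a separate kernel. The matmul descriptors (`d->mul`, `d->q`, `d->k`, `d->att`) and the algorithm choices are built elsewhere in this file; passing half scalars is only valid when the descriptor's scale type is half. The sketch below is a hedged, unbatched illustration of that setup for a single Q·Kᵀ product: the function name, the assumption that Q and K are stored with the head dimension contiguous (column-major headDim × seqLen), and the NULL algorithm pointer are assumptions of the sketch, not the repository's code.

// Sketch only: one unbatched Q·K^T GEMM scaled by 1/sqrt(headDim), with half
// alpha/beta. Assumes CUBLAS_COMPUTE_16F compute type with CUDA_R_16F scale type,
// and q/k stored column-major as headDim x seqLen (head vectors contiguous).
// The real code uses prebuilt, strided-batched descriptors and a chosen algo.
#include <cublasLt.h>
#include <cuda_fp16.h>
#include <cmath>

void scaledQKSketch(cublasLtHandle_t handle,
                    half const *q, half const *k, half *att,
                    int seqLen, int headDim,
                    void *workspace, size_t workspaceSize,
                    cudaStream_t stream) {
    cublasLtMatmulDesc_t mul;
    cublasLtMatmulDescCreate(&mul, CUBLAS_COMPUTE_16F, CUDA_R_16F);
    // att = alpha * q^T * k (column-major), so transpose the first operand
    cublasOperation_t transQ = CUBLAS_OP_T;
    cublasLtMatmulDescSetAttribute(mul, CUBLASLT_MATMUL_DESC_TRANSA, &transQ, sizeof(transQ));

    cublasLtMatrixLayout_t qDesc, kDesc, attDesc;
    cublasLtMatrixLayoutCreate(&qDesc, CUDA_R_16F, headDim, seqLen, headDim);
    cublasLtMatrixLayoutCreate(&kDesc, CUDA_R_16F, headDim, seqLen, headDim);
    cublasLtMatrixLayoutCreate(&attDesc, CUDA_R_16F, seqLen, seqLen, seqLen);

    // With a half scale type, alpha/beta are read as half pointers; the
    // attention scale is converted once on the host and applied by the GEMM.
    half alpha = __float2half(1.f / std::sqrt(static_cast<float>(headDim))),
         beta = __float2half(0.f);
    cublasLtMatmul(handle, mul,
                   &alpha,
                   q, qDesc,
                   k, kDesc,
                   &beta,
                   att, attDesc,
                   att, attDesc,
                   nullptr,// let cuBLASLt pick an algorithm heuristically
                   workspace, workspaceSize,
                   stream);

    cublasLtMatrixLayoutDestroy(attDesc);
    cublasLtMatrixLayoutDestroy(kDesc);
    cublasLtMatrixLayoutDestroy(qDesc);
    cublasLtMatmulDescDestroy(mul);
}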