@@ -7,26 +7,39 @@ namespace refactor::kernel {
 using K = AttentionCuda;
 using namespace cublas;
 
-static __forceinline__ __device__ bool mask(int tokid, int posid) {
-    return true;
-}
+// Attention mask for causal systems.
+// tokenId: index of the token within this pass
+// seqLen: number of tokens processed in this pass
+// posId: position in the kv cache
+// attLen = pastSeqLen + seqLen
+struct AttentionCausalMask {
+    __forceinline__ __device__ bool
+    operator()(int tokenId, int seqLen,
+               int posId, int attLen) const {
+        // tokenId ↓ |<---attLen---->|
+        //         0 | * * ... *     |
+        //         1 | * * ... * *   |
+        //         2 | * * ... * * * |
+        // seqLen: 3 |---------------|
+        return attLen + tokenId >= posId + seqLen;
+    }
+};
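+// Why the inequality holds: this token's absolute position in the cache is
+// pastSeqLen + tokenId = (attLen - seqLen) + tokenId, and a causal token may
+// attend to any posId at or before it, i.e. posId <= attLen - seqLen + tokenId,
+// which rearranges to attLen + tokenId >= posId + seqLen.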
 
 // gridDim.x = batch * nHead
 // gridDim.y = seqLen
+// blockDim.x = min(1024, attLen)
 template<class T, class Mask>
 static __global__ void softmax(
-    T *__restrict__ attention,
+    T *__restrict__ att,
     Mask mask,
-    uint32_t seqLen,
+    uint32_t attLen,
     uint32_t bufLen) {
-    // int offset = (blockIdx.x * len_q + blockIdx.y) * len_buf;
-    // SharedMemory<float> shared;
-    // float *smem = shared.getPointer();
+    // locate the attention row handled by this thread block
+    att += (blockIdx.x * gridDim.y + blockIdx.y) * bufLen;
+    // load the row into shared memory, casting to float and applying the mask
+    extern __shared__ float shared[];// size = attLen = pastSeqLen + seqLen
+    for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
+        shared[i] = mask(blockIdx.y, gridDim.y, i, attLen) ? float(att[i]) : -__FLT_MAX__;
+    }
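+    // Next, still disabled below: block-reduce the max of shared[], compute
+    // exp(x - max), block-reduce the sum, then normalize and write back to
+    // att (the standard numerically stable softmax).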
 
-    // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
-    //     T pb = (position_bias == nullptr) ? T(0.) : position_bias[offset + i];
-    //     smem[i] = mask[blockIdx.y * len_buf + i] > 0 ? x[offset + i] * scale + pb : -Inf<T>();
-    // }
     // float local_max = -1e20;
     // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
     //     local_max = fmaxf(local_max, smem[i]);
@@ -125,7 +138,7 @@ namespace refactor::kernel {
 auto k = inputs[1];
 auto v = inputs[2];
 auto o = outputs[0];
-auto att = workspace;
+auto att = reinterpret_cast<half *>(workspace);
 auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(d->attSize, 256);
 auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256);
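+// workspace layout: | att buffer (attSize) | QK scratch | AV scratch |, with
+// each segment aligned up to 256 bytes.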
 
@@ -143,7 +156,7 @@ namespace refactor::kernel {
 cudaStreamLegacy);
 
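+// One thread block per attention row: gridDim = (batch * nHead, seqLen) and
+// the dynamic shared memory buffer holds one row of att as float.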
-softmax<<<dim3(info.batch * info.nHead, info.seqLen), info.seqLen>>>(
-    att, mask, info.seqLen, info.seqLen);
+softmax<<<dim3(info.batch * info.nHead, info.seqLen),
+          std::min(1024u, info.seqLen),
+          info.seqLen * sizeof(float)>>>(
+    att, AttentionCausalMask(), info.seqLen, info.seqLen);
 
 cublasLtMatmul(
     handle, d->mul.get(),
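
For reference, the mask and softmax above can be mirrored on the host to check the kernel's output. A minimal sketch under the definitions in this diff; `causalMask` and `softmaxRow` are illustrative names, not part of the commit:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Host copy of AttentionCausalMask's predicate.
    static bool causalMask(int tokenId, int seqLen, int posId, int attLen) {
        return attLen + tokenId >= posId + seqLen;
    }

    // Masked, numerically stable softmax over one attention row,
    // mirroring what one thread block computes on the GPU.
    static void softmaxRow(std::vector<float> &row, int tokenId, int seqLen, int attLen) {
        float max = -HUGE_VALF;
        for (int i = 0; i < attLen; ++i) {
            if (!causalMask(tokenId, seqLen, i, attLen)) { row[i] = -HUGE_VALF; }
            max = std::fmax(max, row[i]);
        }
        float sum = 0;
        for (int i = 0; i < attLen; ++i) { sum += row[i] = std::exp(row[i] - max); }
        for (int i = 0; i < attLen; ++i) { row[i] /= sum; }
    }

    int main() {
        int seqLen = 3, attLen = 5;// pastSeqLen = 2
        for (int t = 0; t < seqLen; ++t) {
            std::vector<float> row(attLen, 1.f);
            softmaxRow(row, t, seqLen, attLen);
            // each visible position gets equal weight: 1/3, 1/4, 1/5 per row
            for (float p : row) { std::printf("%.3f ", p); }
            std::printf("\n");
        }
        return 0;
    }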