@@ -7,6 +7,44 @@ namespace refactor::kernel {
     using K = AttentionCuda;
     using namespace cublas;
 
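+    // block-wide reduction used by the softmax below: each thread's partial
+    // is staged in shared memory and folded with a tree reduction; `shared`
+    // is sized for CUDA's 1024-thread block limit. This is a minimal sketch;
+    // a warp-shuffle reduction would be faster but less self-contained.
+    template<class T, class Op>
+    static __device__ T blockReduce(T x, Op op) {
+        __shared__ T shared[1024];
+        __syncthreads();// reuse-safe: wait out any previous reduction
+        shared[threadIdx.x] = x;
+        __syncthreads();
+        for (unsigned stride = 1; stride < blockDim.x; stride <<= 1) {
+            if (threadIdx.x % (stride << 1) == 0 && threadIdx.x + stride < blockDim.x) {
+                shared[threadIdx.x] = op(shared[threadIdx.x], shared[threadIdx.x + stride]);
+            }
+            __syncthreads();
+        }
+        return shared[0];
+    }
+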
+    // placeholder mask functor: every position stays visible, matching the
+    // stub this commit sketches; a causal mask would return `posid <= tokid`
+    // instead. `NoMask` is a stand-in name for illustration. A functor (not a
+    // __device__ function) is used so the host-side launch can pass it.
+    struct NoMask {
+        __forceinline__ __device__ bool operator()(int tokid, int posid) const {
+            return true;
+        }
+    };
+
+    // gridDim.x = batch * nHead
+    // gridDim.y = seqLen
+    // each block softmax-normalizes one attention row in place; the
+    // 1/sqrt(headDim) scale is assumed to be folded into the Q*K^T GEMM
+    template<class T, class Mask>
+    static __global__ void softmax(
+        T *__restrict__ attention,
+        Mask mask,
+        uint32_t seqLen,
+        uint32_t bufLen) {
+        auto att = attention + (blockIdx.x * seqLen + blockIdx.y) * bufLen;
+
+        // hide masked positions, tracking the row maximum for stability
+        auto localMax = -1e20f;
+        for (auto i = threadIdx.x; i < bufLen; i += blockDim.x) {
+            auto v = mask(blockIdx.y, i) ? float(att[i]) : -1e20f;
+            att[i] = T(v);
+            localMax = fmaxf(localMax, v);
+        }
+        auto max = blockReduce(localMax, [](auto a, auto b) { return fmaxf(a, b); });
+
+        // exponentiate against the maximum and accumulate the denominator
+        auto localSum = 1e-20f;
+        for (auto i = threadIdx.x; i < bufLen; i += blockDim.x) {
+            auto v = expf(float(att[i]) - max);
+            att[i] = T(v);
+            localSum += v;
+        }
+        auto sum = blockReduce(localSum, [](auto a, auto b) { return a + b; });
+
+        // normalize the row
+        for (auto i = threadIdx.x; i < bufLen; i += blockDim.x) {
+            att[i] = T(float(att[i]) / sum);
+        }
+    }
+
     RoutineWorkspace K::lower(Resources &res) const {
         auto handle = res.fetchOrStore<CublasLtContext>()->handle;
 
@@ -23,9 +61,9 @@ namespace refactor::kernel {
             size_t attSize, workspaceSizeQK, workspaceSizeAV;
 
             Descriptors(CublasLtContext const &context,
-                        cublasComputeType_t compute,
                         AttentionInfo info)
-                : mul(compute, CUDA_R_32F),
+                : mul(computeTypeConvert(info.dataType),
+                      dataTypeConvert(info.dataType)),
                   q(MatrixLayout{
                       .dataType = dataTypeConvert(info.dataType),
                       .rows = static_cast<uint64_t>(info.seqLen),
@@ -73,11 +111,10 @@ namespace refactor::kernel {
         };
 
         auto const &context = *res.fetchOrStore<CublasLtContext>();
-        auto d = std::make_shared<Descriptors>(context, CUBLAS_COMPUTE_32F, info);
+        auto d = std::make_shared<Descriptors>(context, info);
         auto workspaceSize = d->attSize;
         workspaceSize = hardware::alignBytes(workspaceSize, 256);
         workspaceSize += d->workspaceSizeQK;
-        workspaceSize = hardware::alignBytes(workspaceSize, 256);
         workspaceSize += d->workspaceSizeAV;
         workspaceSize = hardware::alignBytes(workspaceSize, 256);
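+        // the single allocation is thus laid out as [att | workspaceQK | workspaceAV],
+        // with att kept 256-byte aligned and the AV workspace now packed
+        // directly behind the QK workspace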
 
@@ -105,7 +142,8 @@ namespace refactor::kernel {
                 workspaceQK, d->workspaceSizeQK,
                 cudaStreamLegacy);
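+            // ^ the Q*K^T GEMM leaves the raw attention scores in `att`,
+            //   which the softmax kernel below normalizes row by row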
 
-            // TODO inline mask && softmax
+            softmax<<<dim3(info.batch * info.nHead, info.seqLen), info.seqLen>>>(
+                att, NoMask{}, info.seqLen, info.seqLen);
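+            // one block per attention row with seqLen threads each; this
+            // launch is valid while seqLen <= 1024, and NoMask means no
+            // position is actually masked yet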
 
             cublasLtMatmul(
                 handle, d->mul.get(),