
Commit 40fcbee

feat: start implementing attention
Signed-off-by: YdrMaster <[email protected]>
1 parent d076c20 commit 40fcbee

File tree

12 files changed, +409 -105 lines


src/02hardware/include/hardware/devices/nvidia.h

Lines changed: 6 additions & 0 deletions
@@ -3,6 +3,12 @@
 
 #include "../device.h"
 
+#define CUDA_ASSERT(STATUS) \
+    if (auto status = (STATUS); status != cudaSuccess) { \
+        RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
+                                  cudaGetErrorString(status), (int) status)); \
+    }
+
 namespace refactor::hardware {
 
     class Nvidia final : public Device {
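
Note (not part of the commit): with CUDA_ASSERT now exported from hardware/devices/nvidia.h, any translation unit that includes this header together with <cuda_runtime.h> can wrap CUDA runtime calls the way memory.cc does. A minimal sketch:

    void *buf = nullptr;
    CUDA_ASSERT(cudaMalloc(&buf, 1024));   // raises RUNTIME_ERROR with the CUDA error string on failure
    CUDA_ASSERT(cudaMemset(buf, 0, 1024));
    CUDA_ASSERT(cudaFree(buf));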

src/02hardware/src/devices/nvidia/device.cc

Lines changed: 0 additions & 6 deletions
@@ -4,12 +4,6 @@
 #ifdef USE_CUDA
 #include "memory.hh"
 #include <cuda_runtime.h>
-
-#define CUDA_ASSERT(STATUS) \
-    if (auto status = (STATUS); status != cudaSuccess) { \
-        RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
-                                  cudaGetErrorString(status), (int) status)); \
-    }
 #endif
 
 namespace refactor::hardware {

src/02hardware/src/devices/nvidia/memory.cc

Lines changed: 1 addition & 7 deletions
@@ -1,15 +1,9 @@
 #ifdef USE_CUDA
 
 #include "memory.hh"
-#include "common.h"
+#include "hardware/devices/nvidia.h"
 #include <cuda_runtime.h>
 
-#define CUDA_ASSERT(STATUS) \
-    if (auto status = (STATUS); status != cudaSuccess) { \
-        RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
-                                  cudaGetErrorString(status), (int) status)); \
-    }
-
 namespace refactor::hardware {
     using M = NvidiaMemory;
 

src/04kernel/include/kernel/attributes/attention_info.h

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+#ifndef KERNEL_ATTENTION_INFO_H
+#define KERNEL_ATTENTION_INFO_H
+
+#include "../tensor.h"
+
+namespace refactor::kernel {
+
+    struct AttentionInfo {
+        DataType dataType;
+        dim_t batch, nHead, nKVHead, seqLen, headDim, cacheLen;
+        bool concatCache, resetCache;
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_ATTENTION_INFO_H
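
Reading note (not from the commit): the collector below fills these fields straight from the operator's tensors, taking batch/nHead/seqLen/headDim from query.shape and nKVHead from key.shape[1]; cacheLen and the two flags describe the kv-cache mode. A hypothetical no-cache configuration might look like this (the DataType enumerator name is assumed for illustration):

    AttentionInfo info{
        .dataType = DataType::F32,// assumed enumerator name
        .batch = 1,
        .nHead = 32,
        .nKVHead = 32,
        .seqLen = 1024,
        .headDim = 128,
        .cacheLen = 0,// no kv cache
        .concatCache = false,
        .resetCache = false,
    };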

src/04kernel/include/kernel/collectors/attention.h

Lines changed: 1 addition & 2 deletions
@@ -6,9 +6,8 @@
 namespace refactor::kernel {
 
     struct AttentionCollector final : public InfoCollector {
-        dim_t maxSeqLen;
 
-        AttentionCollector(decltype(_target), decltype(maxSeqLen)) noexcept;
+        AttentionCollector(decltype(_target)) noexcept;
 
         std::vector<KernelBox>
             filter(TensorRefs inputs, TensorRefs outputs) const final;

src/04kernel/src/collectors/attention.cc

Lines changed: 37 additions & 18 deletions
@@ -1,38 +1,57 @@
 #include "kernel/collectors/attention.h"
+#include "kernel/attributes/attention_info.h"
 // #include "../kernels/attention/cpu_kernel.hh"
 #include "../kernels/attention/cuda_kernel.hh"
 
 namespace refactor::kernel {
 
     AttentionCollector::AttentionCollector(
-        decltype(_target) target,
-        decltype(maxSeqLen) maxSeqLen_) noexcept
-        : InfoCollector(target),
-          maxSeqLen(maxSeqLen_) {}
+        decltype(_target) target) noexcept
+        : InfoCollector(target) {}
 
     std::vector<KernelBox>
     AttentionCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
         auto const &query = inputs[0].get();
         auto const &key = inputs[1].get();
-        auto pastSeqLen = inputs.size() == 3 ? 0 : *inputs[2].get().data->get<int64_t>();
-        auto cacheLen = outputs.size() == 1 ? 0 : outputs[1].get().shape[2];
 
-        std::vector<KernelBox> ans;
+        AttentionInfo info{
+            .dataType = query.dataType,
+            .batch = query.shape[0],
+            .nHead = query.shape[1],
+            .nKVHead = key.shape[1],
+            .seqLen = query.shape[2],
+            .headDim = query.shape[3],
+            .cacheLen = 0,
+            .concatCache = false,
+            .resetCache = false,
+        };
+        switch (outputs.size()) {
+            case 1:
+                // no kv cache
+                ASSERT(inputs.size() == 3, "");
+                break;
+            case 3:
+                switch (inputs.size()) {
+                    case 6:
+                        info.resetCache = true;
+                    case 4:
+                        info.concatCache = true;
+                    case 3:
+                        info.cacheLen = outputs[1].get().shape[2];
+                        break;
+                    default:
+                        UNREACHABLE();
+                }
+                break;
+            default:
+                UNREACHABLE();
+        }
+
+        std::vector<KernelBox> ans;
         switch (_target) {
             case decltype(_target)::Cpu:
                 break;
             case decltype(_target)::Nvidia: {
-                decltype(AttentionCuda::info) info{
-                    .dataType = query.dataType,
-                    .batch = query.shape[0],
-                    .nHead = query.shape[1],
-                    .nKVHead = key.shape[1],
-                    .pastSeqLen = static_cast<dim_t>(pastSeqLen),
-                    .seqLen = query.shape[2],
-                    .cacheLen = cacheLen,
-                    .headDim = query.shape[3],
-                    .resetCache = false,
-                };
                 if (auto ptr = AttentionCuda::build(info); ptr) {
                     ans.emplace_back(std::move(ptr));
                 }
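
For orientation (my reading of the dispatch above, not text from the commit): the output count decides whether a kv cache is involved, and the inner switch falls through on purpose, so larger input counts enable strictly more flags. The same mapping written without fallthrough:

    if (outputs.size() == 1) {
        // inputs are exactly Q, K, V: no kv cache
    } else {// outputs.size() == 3
        info.cacheLen = outputs[1].get().shape[2];
        info.concatCache = inputs.size() >= 4;// true for 4 or 6 inputs
        info.resetCache = inputs.size() == 6; // true only for 6 inputs
        // any other input count is UNREACHABLE in the original
    }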
src/04kernel/src/kernels/attention/cuda_kernel.cu

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+#include "../../utilities/cuda/cublaslt_utils.cuh"
+#include "cuda_kernel.hh"
+#include "hardware/functions.h"
+
+namespace refactor::kernel {
+    using K = AttentionCuda;
+    using namespace cublas;
+
+    RoutineWorkspace K::lower(Resources &res) const {
+        auto handle = res.fetchOrStore<CublasLtContext>()->handle;
+
+        constexpr auto ROW_MAJOR = CUBLASLT_ORDER_ROW;
+        constexpr auto COL_MAJOR = CUBLASLT_ORDER_COL;
+
+        if (!info.cacheLen) {
+            if (info.nHead == info.nKVHead) {
+                // RAII for closure
+                struct Descriptors {
+                    MatMulDescriptor mul;
+                    MatrixDescriptor q, k, v, att;
+                    cublasLtMatmulAlgo_t algoQK, algoAV;
+                    size_t attSize, workspaceSizeQK, workspaceSizeAV;
+
+                    Descriptors(CublasLtContext const &context,
+                                cublasComputeType_t compute,
+                                AttentionInfo info)
+                        : mul(compute, CUDA_R_32F),
+                          q(MatrixLayout{
+                              .dataType = dataTypeConvert(info.dataType),
+                              .rows = static_cast<uint64_t>(info.seqLen),
+                              .cols = static_cast<uint64_t>(info.headDim),
+                              .majorStride = static_cast<int64_t>(info.headDim),
+                              .order = ROW_MAJOR,
+                              .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                              .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
+                          }),
+                          k(MatrixLayout{
+                              .dataType = dataTypeConvert(info.dataType),
+                              .rows = static_cast<uint64_t>(info.headDim),
+                              .cols = static_cast<uint64_t>(info.seqLen),
+                              .majorStride = static_cast<int64_t>(info.headDim),
+                              .order = COL_MAJOR,
+                              .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                              .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
+                          }),
+                          v(MatrixLayout{
+                              .dataType = dataTypeConvert(info.dataType),
+                              .rows = static_cast<uint64_t>(info.seqLen),
+                              .cols = static_cast<uint64_t>(info.headDim),
+                              .majorStride = static_cast<int64_t>(info.headDim),
+                              .order = ROW_MAJOR,
+                              .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                              .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
+                          }),
+                          att(MatrixLayout{
+                              .dataType = dataTypeConvert(info.dataType),
+                              .rows = static_cast<uint64_t>(info.seqLen),
+                              .cols = static_cast<uint64_t>(info.seqLen),
+                              .majorStride = static_cast<int64_t>(info.seqLen),
+                              .order = ROW_MAJOR,
+                              .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                              .batchStride = static_cast<int64_t>(info.seqLen * info.seqLen),
+                          }),
+                          attSize(info.batch * info.nHead * info.seqLen * info.seqLen * info.dataType.size()) {
+                        auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att);
+                        auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q);
+                        algoQK = algoQK_;
+                        algoAV = algoAV_;
+                        workspaceSizeQK = workspaceSizeQK_;
+                        workspaceSizeAV = workspaceSizeAV_;
+                    }
+                };
+
+                auto const &context = *res.fetchOrStore<CublasLtContext>();
+                auto d = std::make_shared<Descriptors>(context, CUBLAS_COMPUTE_32F, info);
+                auto workspaceSize = d->attSize;
+                workspaceSize = hardware::alignBytes(workspaceSize, 256);
+                workspaceSize += d->workspaceSizeQK;
+                workspaceSize = hardware::alignBytes(workspaceSize, 256);
+                workspaceSize += d->workspaceSizeAV;
+                workspaceSize = hardware::alignBytes(workspaceSize, 256);
+
+                auto routine = [d = std::move(d), info = this->info]//
+                    (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
+                        auto handle = res.fetchOrStore<CublasLtContext>()->handle;
+                        auto q = inputs[0];
+                        auto k = inputs[1];
+                        auto v = inputs[2];
+                        auto o = outputs[0];
+                        auto att = workspace;
+                        auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(d->attSize, 256);
+                        auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256);
+
+                        float alpha = 1, beta = 0;
+                        cublasLtMatmul(
+                            handle, d->mul.get(),
+                            &alpha,
+                            q, d->q.get(),
+                            k, d->k.get(),
+                            &beta,
+                            att, d->att.get(),
+                            att, d->att.get(),
+                            &d->algoQK,
+                            workspaceQK, d->workspaceSizeQK,
+                            cudaStreamLegacy);
+
+                        // TODO inline mask && softmax
+
+                        cublasLtMatmul(
+                            handle, d->mul.get(),
+                            &alpha,
+                            att, d->att.get(),
+                            v, d->v.get(),
+                            &beta,
+                            o, d->q.get(),
+                            o, d->q.get(),
+                            &d->algoAV,
+                            workspaceAV, d->workspaceSizeAV,
+                            cudaStreamLegacy);
+                    };
+                return {std::move(routine), workspaceSize};
+            }
+        }
+        TODO("");
+    }
+
+}// namespace refactor::kernel
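
A note on the implementation above (my reading, not part of the commit): both cublasLtMatmul calls are batched over batch * nHead matrices, the K descriptor is declared COL_MAJOR so the row-major K buffer is consumed as its transpose, and the single workspace packs the attention matrix plus the two matmul workspaces, each region aligned to 256 bytes:

    // att[seqLen, seqLen] = Q[seqLen, headDim] x K^T[headDim, seqLen]   (first cublasLtMatmul)
    // ... mask and softmax on att: still the TODO in between ...
    // O[seqLen, headDim]  = att[seqLen, seqLen] x V[seqLen, headDim]    (second cublasLtMatmul)
    //
    // workspace layout: [ att | QK matmul workspace | AV matmul workspace ], 256-byte aligned boundaries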

src/04kernel/src/kernels/attention/cuda_kernel.hh

Lines changed: 2 additions & 6 deletions
@@ -1,17 +1,13 @@
 #ifndef KERNEL_ATTENTION_CUDA_KERNEL_HH
 #define KERNEL_ATTENTION_CUDA_KERNEL_HH
 
+#include "kernel/attributes/attention_info.h"
 #include "kernel/kernel.h"
-#include "kernel/tensor.h"
 
 namespace refactor::kernel {
 
     struct AttentionCuda final : public Kernel {
-        struct {
-            DataType dataType;
-            dim_t batch, nHead, nKVHead, pastSeqLen, seqLen, cacheLen, headDim;
-            bool resetCache;
-        } info;
+        AttentionInfo info;
 
         AttentionCuda(decltype(info)) noexcept;
 

src/04kernel/src/utilities/cuda/cublaslt_context.cu

Lines changed: 0 additions & 33 deletions
This file was deleted.

src/04kernel/src/utilities/cuda/cublaslt_context.hh

Lines changed: 0 additions & 33 deletions
This file was deleted.
