Commit 49bb7d8

perf(llama.cu): update flash attn to improve performance

Signed-off-by: YdrMaster <[email protected]>
1 parent 6063373 commit 49bb7d8

File tree

5 files changed (+57, −73 lines)

Cargo.lock

Lines changed: 35 additions & 34 deletions
Some generated files are not rendered by default.

llama.cu/Cargo.toml

Lines changed: 8 additions & 8 deletions
@@ -4,10 +4,10 @@ version = "0.0.0"
 edition.workspace = true
 
 [dependencies]
-cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "31c8090" }
-cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "31c8090" }
-nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "31c8090" }
-flash-attn = { git = "https://github.com/YdrMaster/learn-flash-attn", rev = "57176f5" }
+cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "803d64b" }
+cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "803d64b" }
+nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "803d64b" }
+flash-attn = { git = "https://github.com/YdrMaster/learn-flash-attn", rev = "1caeeee" }
 nn = { git = "https://github.com/YdrMaster/InfiniNN", rev = "171c5b0" }
 ggus = { git = "https://github.com/InfiniTensor/gguf", rev = "23c362f" }
 tokeneer = { git = "https://github.com/InfiniTensor/tokeneer", rev = "c48f39f" }

@@ -27,7 +27,7 @@ minijinja = { version = "2.11", default-features = false, features = [
 
 [build-dependencies]
 build-script-cfg = "0.1"
-search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "31c8090" }
-search-maca-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "31c8090" }
-search-corex-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "31c8090" }
-cuda-cc = { git = "https://github.com/YdrMaster/cuda-driver", rev = "31c8090" }
+search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "803d64b" }
+search-maca-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "803d64b" }
+search-corex-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "803d64b" }
+cuda-cc = { git = "https://github.com/YdrMaster/cuda-driver", rev = "803d64b" }
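
All of the cuda-driver crates (cuda, cublas, nccl, and the build-time search/cc tools) move together to rev 803d64b, so they keep resolving to a single checkout. The flash-attn bump (57176f5 → 1caeeee) also swaps the kernel code-generation entry point from string-typed to generic, which the step.rs diff below adapts to. A minimal self-contained sketch of that pattern; the Nvdt trait and kernel_code* functions here are hypothetical stand-ins, not the crate's real items (those are NVDT and cuda::code):

// Hypothetical mock of the old vs. new codegen pattern.
trait Nvdt {
    // CUDA source-level names for the compute and storage types.
    const T_COMPUTE: &'static str;
    const T_DATA: &'static str;
}

struct F16;
impl Nvdt for F16 {
    const T_COMPUTE: &'static str = "float";
    const T_DATA: &'static str = "half";
}

// Old style: callers pass type names as strings and can mismatch the pair.
fn kernel_code_untyped(t_compute: &str, t_data: &str) -> String {
    format!("__global__ void attn({t_compute} scale, {t_data} *q);")
}

// New style: the pair is fixed by the generic parameter.
fn kernel_code<T: Nvdt>() -> String {
    kernel_code_untyped(T::T_COMPUTE, T::T_DATA)
}

fn main() {
    assert_eq!(kernel_code::<F16>(), kernel_code_untyped("float", "half"));
}

The generic form makes an inconsistent pair such as ("half", "half") unrepresentable at the call site, which is why the [t_compute, t_data] match disappears from step.rs.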

llama.cu/src/exec/step.rs

Lines changed: 12 additions & 29 deletions
@@ -5,11 +5,11 @@
     utils::{destruct, distinct, offset_ptr, strides},
 };
 use cuda::{CaptureStream, GraphExec, Module, Stream, VirByte};
-use flash_attn::attention::{FlashAttnCfg, KVPage, KernelReq, Strides2D};
+use flash_attn::attention::{AttnType, FlashAttnCfg, KVPage, KernelReq, Strides2D};
 use ggus::ggml_quants::f16;
 use nn::{Arg, Named, Tensor, digit_layout::types};
 use regex::Regex;
-use std::{fmt, sync::LazyLock};
+use std::{fmt, ptr::null, sync::LazyLock};
 
 pub(super) enum Step<'ctx> {
     Graph(GraphExec<'ctx>, Box<[Tensor<*const VirByte, 2>]>),

@@ -144,31 +144,28 @@ impl<'ctx> Handle<'ctx> {
         reqs: &[Req<Tensor<*const VirByte, 2>>],
         stream: &Stream,
     ) {
+        use ::flash_attn::attention::cuda::code as flash_attn_code;
         let Attention { q, k, v, o, .. } = attn;
         let dt = distinct(&[q.dt(), k.dt(), v.dt(), o.dt()]).unwrap();
         // compile
         let key = [ModuleKey::Text("flash-attn"), ModuleKey::Type(dt)].into_iter();
-        let [t_compute, t_data] = match dt {
-            types::F16 => ["float", "half"],
-            _ => todo!(),
-        };
-        let module = self.compile(key.collect(), || {
-            ::flash_attn::attention::cuda::code(t_compute, t_data)
-        });
         match dt {
-            types::F16 => launch_attn_typed::<f16>(attn, reqs, module, stream),
+            types::F16 => {
+                let module = self.compile(key.collect(), || flash_attn_code::<f16>());
+                launch_attn_typed::<f16>(attn, reqs, module, stream)
+            }
             _ => todo!(),
         }
     }
 }
 
-fn launch_attn_typed<T: Copy>(
+fn launch_attn_typed<T: ::flash_attn::attention::cuda::NVDT>(
     attn: &Attention,
     reqs: &[Req<Tensor<*const VirByte, 2>>],
     module: &Module,
     stream: &Stream,
 ) {
-    const TILE_SEQ: usize = 32;
+    const TILE_SEQ: usize = 8;
     const TILE_CTX: usize = 32;
 
     let Attention { iblk, q, k, v, o } = attn;

@@ -230,25 +227,10 @@ fn launch_attn_typed<T: Copy>(
             })
         })
         .collect::<Box<_>>();
-    // generate mask
-    let masks = reqs
-        .iter()
-        .map(|req| {
-            let Req { pos, seq: n, .. } = req;
-            let s = pos + n;
-            let s_ceil = s.div_ceil(TILE_CTX) * TILE_CTX;
-            // attention mask
-            let mask = (0..n * s_ceil)
-                .map(|i| i % s_ceil <= s - n + i / s_ceil)
-                .collect::<Box<_>>();
-            stream.from_host(&mask)
-        })
-        .collect::<Box<_>>();
     // generate blocks for each head of each request
     let reqs_ = reqs
         .iter()
-        .zip(&masks)
-        .scan((0, 0), |(seq, page), (req, mask)| {
+        .scan((0, 0), |(seq, page), req| {
            let &Req {
                ref cache,
                pos,

@@ -288,9 +270,10 @@
                 kv_strides,
                 o: offset_ptr(&o).cast_mut().cast(),
                 o_strides,
-                mask: mask.as_ptr().cast(),
                 n,
                 s: pos + n,
+                ty: AttnType::Causal,
+                mask: null(),
             })
         })
         .collect::<Box<_>>();
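
The performance change in this file is the removed block above: every request used to materialize an n × s_ceil boolean mask on the host and upload it with stream.from_host before launch. The updated kernel instead takes ty: AttnType::Causal with a null mask pointer (presumably ignored in the causal case) and can recover the same predicate from indices alone. A self-contained sketch checking that the removed mask formula reduces to a pure index predicate; all names here are illustrative, not the crate's API:

// The removed host-side construction: an n x s_ceil boolean mask per request.
fn explicit_causal_mask(pos: usize, n: usize, tile_ctx: usize) -> Box<[bool]> {
    let s = pos + n;
    let s_ceil = s.div_ceil(tile_ctx) * tile_ctx;
    (0..n * s_ceil)
        .map(|i| i % s_ceil <= s - n + i / s_ceil)
        .collect()
}

// Equivalent per-element rule: query row r (0..n) may attend to context
// column c iff c <= pos + r, since s - n = pos.
fn causal(pos: usize, r: usize, c: usize) -> bool {
    c <= pos + r
}

fn main() {
    let (pos, n, tile_ctx) = (5, 3, 32);
    let s_ceil = (pos + n).div_ceil(tile_ctx) * tile_ctx;
    let mask = explicit_causal_mask(pos, n, tile_ctx);
    for r in 0..n {
        for c in 0..s_ceil {
            assert_eq!(mask[r * s_ceil + c], causal(pos, r, c));
        }
    }
}

Since the kernel already knows pos and n for each request, evaluating c <= pos + r in registers saves a host allocation and a host-to-device copy per request per step; the TILE_SEQ change (32 → 8) retunes the tile shape, presumably for decode-heavy workloads where n is small.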

xtask/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ ratatui = "0.29"
 serde.workspace = true
 serde_json = "1.0"
 toml = "0.8"
-tokio = { version = "1.46", features = ["rt-multi-thread", "net"] }
+tokio = { version = "1.47", features = ["rt-multi-thread", "net"] }
 hyper = { version = "1.6", features = ["http1", "server"] }
 hyper-util = { version = "0.1", features = ["http1", "tokio", "server"] }
 http-body-util = "0.1"

xtask/src/bench.rs

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ impl BenchArgs {
         let mut steps = 0;
         loop {
             let time = Instant::now();
-            let Received { sessions, outputs } = service.recv(Duration::from_millis(100));
+            let Received { sessions, outputs } = service.recv(Duration::MAX);
             let time = time.elapsed();
             println!("{steps:03}. time = {time:?}");
             steps += 1;
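
With Duration::MAX the benchmark presumably blocks until output actually arrives instead of waking every 100 ms with nothing to report, so each printed step time measures a real generation step. A sketch of the polling-versus-blocking distinction, using std's mpsc channel as a stand-in for the service (everything in this snippet is illustrative):

use std::sync::mpsc::{channel, RecvTimeoutError};
use std::thread;
use std::time::Duration;

fn main() {
    let (tx, rx) = channel();
    thread::spawn(move || {
        thread::sleep(Duration::from_millis(350));
        tx.send("step output").unwrap();
    });

    // Old style: a 100 ms timeout returns empty-handed until data arrives,
    // so the measuring loop counts wake-ups that did no work.
    let mut empty_wakeups = 0;
    let out = loop {
        match rx.recv_timeout(Duration::from_millis(100)) {
            Ok(out) => break out,
            Err(RecvTimeoutError::Timeout) => empty_wakeups += 1,
            Err(RecvTimeoutError::Disconnected) => panic!("service stopped"),
        }
    };
    println!("got {out:?} after {empty_wakeups} empty wake-ups");

    // New style: a plain blocking receive; every iteration has real work.
    // service.recv(Duration::MAX) in bench.rs plays the role of rx.recv().
}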
