Skip to content

Commit 2d560ec

Browse files
committed
perf: 优化性能
Signed-off-by: YdrMaster <[email protected]>
1 parent 49bb7d8 commit 2d560ec

File tree

6 files changed

+66
-44
lines changed

6 files changed

+66
-44
lines changed

Cargo.lock

Lines changed: 39 additions & 21 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

llama.cu/Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,10 @@ version = "0.0.0"
44
edition.workspace = true
55

66
[dependencies]
7-
cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "803d64b" }
8-
cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "803d64b" }
9-
nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "803d64b" }
10-
flash-attn = { git = "https://github.com/YdrMaster/learn-flash-attn", rev = "1caeeee" }
7+
cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6a97931" }
8+
cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6a97931" }
9+
nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6a97931" }
10+
flash-attn = { git = "https://github.com/YdrMaster/learn-flash-attn", rev = "616bbac" }
1111
nn = { git = "https://github.com/YdrMaster/InfiniNN", rev = "171c5b0" }
1212
ggus = { git = "https://github.com/InfiniTensor/gguf", rev = "23c362f" }
1313
tokeneer = { git = "https://github.com/InfiniTensor/tokeneer", rev = "c48f39f" }

llama.cu/src/exec/step.rs

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -147,11 +147,19 @@ impl<'ctx> Handle<'ctx> {
147147
use ::flash_attn::attention::cuda::code as flash_attn_code;
148148
let Attention { q, k, v, o, .. } = attn;
149149
let dt = distinct(&[q.dt(), k.dt(), v.dt(), o.dt()]).unwrap();
150+
let d = *q.shape().last().unwrap();
150151
// 编译
151-
let key = [ModuleKey::Text("flash-attn"), ModuleKey::Type(dt)].into_iter();
152+
let key = [
153+
ModuleKey::Text("flash-attn"),
154+
ModuleKey::Type(dt),
155+
ModuleKey::Size(d),
156+
]
157+
.into_iter();
152158
match dt {
153159
types::F16 => {
154-
let module = self.compile(key.collect(), || flash_attn_code::<f16>());
160+
let module = self.compile(key.collect(), || {
161+
flash_attn_code::<f16>(d, stream.ctx().dev().warp_size())
162+
});
155163
launch_attn_typed::<f16>(attn, reqs, module, stream)
156164
}
157165
_ => todo!(),

llama.cu/src/handle.rs

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
use crate::op::ModuleKey;
22
use cublas::Cublas;
3-
use cuda::{CurrentCtx, Module, Ptx};
3+
use cuda::{CurrentCtx, Module, Rtc};
44
use std::collections::HashMap;
55

66
#[cfg(nccl)]
@@ -37,9 +37,11 @@ impl<'ctx> Handle<'ctx> {
3737

3838
pub fn compile(&mut self, key: Box<[ModuleKey]>, code: impl FnOnce() -> String) -> &Module {
3939
self.modules.entry(key).or_insert_with(|| {
40-
let (ptx, log) = Ptx::compile(code(), self.ctx.dev().compute_capability());
41-
let Ok(ptx) = ptx else { panic!("{log}") };
42-
self.ctx.load(&ptx)
40+
let program = Rtc::new()
41+
.arch(self.ctx.dev().compute_capability())
42+
.compile(&code())
43+
.unwrap();
44+
self.ctx.load(&program)
4345
})
4446
}
4547

llama.cu/src/op/random_sample/modifier.rs

Lines changed: 6 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,7 @@
11
//! <https://zhuanlan.zhihu.com/p/667025336>
22
33
use crate::utils::offset_ptr;
4-
use cuda::{CurrentCtx, DevByte, DevMem, Module, Ptx, Stream, VirByte, params};
5-
use log::warn;
4+
use cuda::{CurrentCtx, DevByte, DevMem, Module, Rtc, Stream, VirByte, params};
65
use nn::Tensor;
76
use std::ffi::c_uint;
87
use tokeneer::utok;
@@ -77,15 +76,10 @@ extern "C" __global__ void next(
7776
next_kernel(logits, records, n, eos, temperature, penalty, tok);
7877
}}"#
7978
);
80-
let (ptx, log) = Ptx::compile(code, ctx.dev().compute_capability());
81-
match ptx {
82-
Ok(ptx) => {
83-
if !log.is_empty() {
84-
warn!("{log}")
85-
}
86-
ctx.load(&ptx)
87-
}
88-
Err(e) => panic!("logits modify compilation failed with {e:?}, log:\n {log}"),
89-
}
79+
let program = Rtc::new()
80+
.arch(ctx.dev().compute_capability())
81+
.compile(&code)
82+
.unwrap();
83+
ctx.load(&program)
9084
}
9185
}

xtask/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@ fn main() {
3535

3636
#[derive(Parser)]
3737
#[clap(name = "InfiniLM")]
38-
#[clap(version, about, long_about = None)]
38+
#[clap(version, long_about = None)]
3939
struct Cli {
4040
#[clap(subcommand)]
4141
command: Commands,

0 commit comments

Comments
 (0)