Commit e73b0fb

YdrMaster committed
fix(llama-cuda): support computing the possible kv cache capacity from free memory
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 5622cb8 commit e73b0fb

File tree:

Cargo.toml
models/llama/common/src/lib.rs
models/llama/cuda/src/infer.rs

3 files changed: +15 −4 lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ itertools = "0.13"
 env_logger = "0.11"
 build-script-cfg = "0.0"
 
-operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "7886d54", default-features = false }
+operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "359b86a", default-features = false }
 
 search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "f69b160" }
 search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "e8362c3" }

models/llama/common/src/lib.rs

Lines changed: 4 additions & 0 deletions
@@ -106,6 +106,10 @@ impl LlamaMeta {
         Tensor::new(dt_embd, &[buf, nblk, 2, nkvh, dh])
     }
 
+    pub fn kv_cache_in_size(&self, max: usize, size: usize) -> Tensor<usize> {
+        self.kv_cache((size / self.kv_cache(1).take()).min(max))
+    }
+
     pub fn embd(&self, nt: usize) -> Tensor<usize> {
         let &Self { dt_embd, d, .. } = self;
         Tensor::new(dt_embd, &[nt, d])
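For context, `kv_cache_in_size` turns a byte budget into a token capacity: `self.kv_cache(1).take()` gives the byte size of a one-token cache, the budget is divided by it, and the result is capped at `max`. A minimal standalone sketch of the same arithmetic follows; the shape mirrors the `[buf, nblk, 2, nkvh, dh]` layout above, but the function name and the f16 element size are illustrative assumptions, not the crate's API:

// Hypothetical stand-in for the capacity formula behind kv_cache_in_size.
fn kv_cache_capacity(
    nblk: usize,    // transformer blocks (layers)
    nkvh: usize,    // kv attention heads
    dh: usize,      // head dimension
    dt_size: usize, // bytes per element, e.g. 2 for f16
    max: usize,     // context-length cap (nctx)
    size: usize,    // byte budget for the whole cache
) -> usize {
    // one token costs nblk * 2 (k and v) * nkvh * dh elements
    let bytes_per_token = nblk * 2 * nkvh * dh * dt_size;
    (size / bytes_per_token).min(max)
}

fn main() {
    // 32 blocks, 8 kv heads, head dim 128, f16: 131072 bytes (128 KiB)
    // per token, so a 4 GiB budget holds 32768 tokens before the cap.
    assert_eq!(kv_cache_capacity(32, 8, 128, 2, 65536, 4 << 30), 32768);
    assert_eq!(kv_cache_capacity(32, 8, 128, 2, 8192, 4 << 30), 8192);
}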

models/llama/cuda/src/infer.rs

Lines changed: 10 additions & 3 deletions
@@ -72,13 +72,20 @@ fn test_infer() {
     println!("load weights: {:?}", time.elapsed());
 
     let (free, _) = ctx.mem_info();
-    let queue_alloc = StreamMemPool::new(stream);
-    queue_alloc.put((free.0 >> 30) << 30);
+    let mut cache = meta
+        // use half of the remaining free space for the kv cache
+        .kv_cache_in_size(nctx, free.0 / 2)
+        .map(|len| ctx.malloc::<u8>(len));
+    println!("cache len = {}", cache.shape()[0]);
 
+    let queue_alloc = StreamMemPool::new(stream);
     let alloc = |size| -> MemPoolBlob { queue_alloc.alloc(size) };
 
+    let (free, _) = ctx.mem_info();
+    // drop the remainder below 64 MiB
+    queue_alloc.put(free.0 & !((64 << 20) - 1));
+
     let mut worker = Worker::new(0, &gpu, meta.clone(), weights);
-    let mut cache = meta.kv_cache(nctx).map(alloc);
     let sin_cos =
         <Operators as llama::Operators>::build_sin_cos(dt_embd, nctx, dh, &queue_alloc);
     let indices = RandomSample::build_indices(nvoc, &queue_alloc);
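The removed code handed the pool whole GiB only (`(free.0 >> 30) << 30`), which could leave almost 1 GiB unused; the new expression `free.0 & !((64 << 20) - 1)` rounds the remaining free memory down to a 64 MiB boundary instead. A quick sketch of the mask trick, with made-up values for illustration:

fn main() {
    // rounding down to a power-of-two boundary by clearing the low bits
    let align: usize = 64 << 20; // 64 MiB, a power of two
    let mask = !(align - 1);     // ANDing clears bits 0..=25

    let free: usize = 10_000_000_000; // pretend ~9.31 GiB is free
    let usable = free & mask;

    assert_eq!(usable % align, 0);  // a whole number of 64 MiB chunks
    assert!(free - usable < align); // at most 64 MiB is left unused
    println!("put {usable} of {free} bytes into the pool");
}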
