Commit 495e4b8

feat(llama.cu): support mamba model inference

1 parent 71c1e35 commit 495e4b8

File tree

19 files changed: +1543 −10 lines

Cargo.lock

Lines changed: 6 additions & 6 deletions
Generated lockfile; diff not rendered by default.

llama.cu/Cargo.toml

Lines changed: 2 additions & 2 deletions
@@ -8,9 +8,9 @@ cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6a97931" }
 cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6a97931" }
 nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6a97931" }
 flash-attn = { git = "https://github.com/YdrMaster/learn-flash-attn", rev = "616bbac" }
-nn = { git = "https://github.com/YdrMaster/InfiniNN", rev = "fa8aaf6" }
+nn = { git = "https://github.com/CearX/InfiniNN", rev = "6caef2" }
 ggus = { git = "https://github.com/InfiniTensor/gguf", rev = "23c362f" }
-tokeneer = { git = "https://github.com/InfiniTensor/tokeneer", rev = "c48f39f" }
+tokeneer = { git = "https://github.com/CearX/tokeneer.git", rev = "2546d72" }

 bytesize = "2.0"
 log.workspace = true

llama.cu/src/exec/group.rs

Lines changed: 133 additions & 1 deletion
@@ -1,8 +1,10 @@
+use super::mamba_cache::MambaCache;
 use super::{CacheParts, Progress, model::ModelExec, upos};
 use crate::{batch::Req, handle::Handle, load::load_weight, memory::MemPages};
 use cuda::{DevByte, DevMem, Stream, VirByte};
 use nn::{
-    Distribution, Graph, GraphBuilder, LLaMA, NNGraph, Tensor, TensorMeta, digit_layout::types, op,
+    Distribution, Graph, GraphBuilder, LLaMA, Mamba, NNGraph, Tensor, TensorMeta,
+    digit_layout::types, op,
 };
 use std::{
     collections::BTreeMap,
@@ -240,3 +242,133 @@ fn builder() -> GraphBuilder {
         .register_op("all-reduce", op::all_reduce::AllReduce);
     ans
 }
+
+// Mamba GraphBuilder
+fn builder_mamba() -> GraphBuilder {
+    let mut ans = GraphBuilder::default();
+    ans.register_op("embedding", op::embedding::Embedding)
+        .register_op("rms-norm", op::normalization::RmsNorm)
+        .register_op("linear", op::linear::Linear)
+        .register_op("silu", op::activation::SiLU)
+        .register_op("element-mul", op::element_mul::ElementMul)
+        .register_op("split", op::split::Split)
+        .register_op("mamba-causal-conv1d", op::mamba::CausalConv1d)
+        .register_op("mamba-selective-scan", op::mamba::SelectiveScan);
+    ans
+}
+
+pub(crate) struct ModelGroupMamba<'ctx> {
+    internal: Internal<'ctx>,
+    pages: MemPages,
+    _weight: DevMem<'ctx>,
+    next_pos: u32,
+}
+
+impl<'ctx> ModelGroupMamba<'ctx> {
+    pub fn new<T: IntoIterator<Item = usize>>(
+        mamba: Mamba<Tensor<&[u8], 2>>,
+        dist: Distribution,
+        progress: Option<Arc<Progress>>,
+        config: ModelGroupConfig<T>,
+        handle: &mut Handle<'ctx>,
+        barrier: Option<&Barrier>,
+    ) -> Self {
+        let ModelGroupConfig {
+            static_model_keys,
+            mut dyn_cache_size,
+            use_cuda_graph,
+        } = config;
+
+        let NNGraph(Graph { topo, nodes, edges }) = builder_mamba()
+            .build(
+                mamba.tensor_parallel(dist),
+                [
+                    TensorMeta::new(types::U32, ["n_tok".into()]),
+                    TensorMeta::new(types::U32, ["n_tok".into()]),
+                    TensorMeta::new(types::U32, ["n_tok".into()]),
+                ],
+            )
+            .unwrap();
+        handle.ctx.stream().synchronize();
+
+        let dev = handle.ctx.dev();
+        let mut pages = MemPages::new(dev);
+        let (_weight, edges) = load_weight(edges, progress, handle.ctx);
+        let graph = NNGraph(Graph { topo, nodes, edges });
+        let static_models = if use_cuda_graph {
+            static_model_keys
+                .into_iter()
+                .map(|n_tok| {
+                    if let Some(b) = barrier {
+                        b.wait();
+                    }
+                    let key = NonZeroUsize::new(n_tok).unwrap();
+                    let exec = ModelExec::new(graph.clone(), n_tok, handle, &mut pages, true);
+                    (key, exec)
+                })
+                .collect::<BTreeMap<_, _>>()
+        } else {
+            dyn_cache_size += static_model_keys.into_iter().count();
+            Default::default()
+        };
+
+        let internal = Internal::new(graph, static_models, dyn_cache_size);
+        Self {
+            internal,
+            pages,
+            _weight,
+            next_pos: 0,
+        }
+    }
+
+    pub fn load_inputs_mamba_prefill(
+        &mut self,
+        handle: &mut Handle<'ctx>,
+        len: usize,
+        tok: &[utok],
+        stream: &Stream<'ctx>,
+    ) -> (NonZeroUsize, &mut [DevByte]) {
+        let key = self.internal.get_key(NonZeroUsize::new(len).unwrap());
+        let model = self.internal.map_exec(key, handle, &mut self.pages, stream);
+        stream.memcpy_h2d(model.tok_buf(), &tok[..key.get()]);
+        let pos: Vec<upos> = (0..key.get()).map(|i| i as upos).collect();
+        stream.memcpy_h2d(model.pos_buf(), &pos);
+        self.next_pos = key.get() as u32;
+        let out_idx: Vec<utok> = (0..key.get()).map(|i| i as utok).collect();
+        let buf = model.input_buf_at(2);
+        stream.memcpy_h2d(buf, &out_idx);
+        (key, model.tok_buf())
+    }
+
+    pub fn load_input_mamba_decode(
+        &mut self,
+        handle: &mut Handle<'ctx>,
+        tok: utok,
+        stream: &Stream<'ctx>,
+    ) -> (NonZeroUsize, &mut [DevByte]) {
+        let key = self.internal.get_key(NonZeroUsize::new(1).unwrap());
+        let model = self.internal.map_exec(key, handle, &mut self.pages, stream);
+        let tok_buf = model.tok_buf();
+        stream.memcpy_h2d(tok_buf, &[tok]);
+        let pos_buf = model.pos_buf();
+        let cur = self.next_pos;
+        stream.memcpy_h2d(pos_buf, &[cur]);
+        // advance next_pos
+        self.next_pos = cur.saturating_add(1);
+        // out_idx is fixed to 0 during decode
+        let out_idx_buf = model.input_buf_at(2);
+        stream.memcpy_h2d(out_idx_buf, &[0u32]);
+        (key, model.tok_buf())
+    }
+
+    pub fn launch_mamba(
+        &mut self,
+        key: NonZeroUsize,
+        cache: &mut MambaCache,
+        handle: &mut Handle,
+        stream: &Stream<'ctx>,
+    ) -> Tensor<*const VirByte, 2> {
+        let model = self.internal.get_mut(&key).unwrap();
+        model.launch_with_mamba_cache(handle, cache, stream)
+    }
+}
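
The two new graph ops, "mamba-causal-conv1d" and "mamba-selective-scan", do the layer's sequence mixing. As a reference for what the scan computes, here is a minimal single-channel CPU sketch of the standard Mamba selective-scan recurrence; the function name and signature are illustrative only, not the actual op::mamba::SelectiveScan kernel interface:

// Reference semantics of the selective scan for one channel (sketch).
// Per step t, with recurrent state h of length d_state:
//   h[s] = exp(dt[t] * a[s]) * h[s] + dt[t] * x[t] * b[t][s]
//   y[t] = sum_s c[t][s] * h[s] + d_skip * x[t]
fn selective_scan_ref(
    x: &[f32],       // input sequence for one channel
    dt: &[f32],      // per-step discretization Δ, same length as x
    a: &[f32],       // state decay parameters, length d_state
    b: &[Vec<f32>],  // per-step input projections, each of length d_state
    c: &[Vec<f32>],  // per-step output projections, each of length d_state
    d_skip: f32,     // skip-connection weight
) -> Vec<f32> {
    let mut h = vec![0.0f32; a.len()]; // state carried across time steps
    (0..x.len())
        .map(|t| {
            let mut y = 0.0;
            for s in 0..a.len() {
                h[s] = (dt[t] * a[s]).exp() * h[s] + dt[t] * x[t] * b[t][s];
                y += c[t][s] * h[s];
            }
            y + d_skip * x[t]
        })
        .collect()
}

Because the recurrence only carries this fixed-size h (plus the short causal-conv window), decoding needs no KV cache; ModelGroupMamba just tracks a scalar next_pos and reuses the per-layer state held in MambaCache.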

llama.cu/src/exec/mamba.rs

Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
+use crate::exec::group::{ModelGroupConfig, ModelGroupMamba};
+use crate::exec::mamba_cache::MambaCache;
+use crate::exec::output_head::OutputHead;
+use crate::exec::sample_manager::SampleManager;
+use crate::memory::MemPages;
+use crate::op::random_sample::{KVPair, SampleArgs};
+use crate::utils::{self, meta};
+use crate::{handle::Handle, model::map_files};
+use cuda::Device;
+use ggus::GGufMetaMapExt;
+use nn::Distribution;
+use std::env;
+use tokeneer::Bpe;
+
+#[allow(dead_code)]
+pub fn mamba_infer(
+    model_path: std::path::PathBuf,
+    text: &str,
+    use_cuda_graph: bool,
+) -> (Vec<u8>, [usize; 2]) {
+    use crate::model::GGufModel;
+    // initialize CUDA
+    assert!(cuda::init().is_ok());
+
+    // load the model
+    let maps = map_files(model_path);
+    let gguf = GGufModel::read(maps.iter().map(|x| &**x));
+    let tokenizer = Bpe::from_gguf(&gguf);
+    let mut tokens = tokenizer.encode(text);
+
+    let n_tok = tokens.len();
+
+    // take the output head for computing logits
+    let mut mamba = gguf.mamba();
+    let output_head_nn = mamba
+        .output_head
+        .take()
+        .expect("mamba model missing output_head");
+
+    let n_layer: usize = meta![gguf => llm_block_count];
+    let d_inner: usize = 5120; // TODO: ggus
+    let d_conv: usize = 4; // kernel size
+    let d_state: usize = 16; // ssm state size
+
+    // single GPU
+    let device = Device::new(0);
+    device.retain_primary().apply(|ctx| {
+        let mut handle = Handle::new(ctx);
+        let dist = Distribution {
+            start: 0,
+            len: 1,
+            total: 1,
+        };
+        let mut models = ModelGroupMamba::new(
+            mamba,
+            dist,
+            None,
+            ModelGroupConfig {
+                static_model_keys: [n_tok],
+                dyn_cache_size: 1,
+                use_cuda_graph,
+            },
+            &mut handle,
+            None,
+        );
+
+        // components: output head and sampler
+        let stream = ctx.stream();
+        let mut output_head = OutputHead::new(output_head_nn, ctx);
+
+        // read the vocabulary size to build the sampler, plus the eos id
+        let eos: tokeneer::utok = meta![gguf => tokenizer_ggml_eos_token_id];
+        let nvoc = output_head.nvoc();
+        let mut sample_manager = SampleManager::new(nvoc, eos, ctx);
+
+        // initialize the MambaCache
+        let mut pages = MemPages::new(device);
+        let mut mcache = MambaCache::new(n_layer, d_inner, d_conv, d_state, &mut pages);
+
+        // prefill
+        let (key, _tok_buf) =
+            models.load_inputs_mamba_prefill(&mut handle, tokens.len(), &tokens, &stream);
+
+        let mut x = models.launch_mamba(key, &mut mcache, &mut handle, &stream);
+
+        let last_idx: [tokeneer::utok; 1] = [(tokens.len() - 1) as tokeneer::utok];
+        let logits_prefill_last = output_head.launch(x.clone(), &last_idx, &mut handle, &stream);
+
+        let logits_prefill_last_vir = logits_prefill_last.as_ref().map(|mem| mem.as_ptr().cast());
+        utils::fmt(&logits_prefill_last_vir, stream.ctx());
+        // check prefill logits
+
+        let mut next_id: tokeneer::utok;
+        {
+            let mut input = stream.malloc::<tokeneer::utok>(tokens.len());
+            stream.memcpy_h2d(&mut input, &tokens);
+            let cfg0 = vec![(
+                crate::batch::SessionId(0),
+                crate::batch::SampleInfo {
+                    args: SampleArgs::new(0.8, 0.95, 50, 1.2).unwrap(),
+                    input_idx: tokens.len(),
+                    decode_len: tokens.len(),
+                },
+            )];
+            let kv_pairs0 = sample_manager.sample(logits_prefill_last, &input, &cfg0, &stream);
+            stream.free(input);
+            let mut host_kv0 = vec![KVPair::ZERO; 1];
+            stream.memcpy_d2h(&mut host_kv0, &kv_pairs0).free(kv_pairs0);
+            next_id = host_kv0[0].idx as tokeneer::utok;
+        }
+        let mut generated: Vec<tokeneer::utok> = Vec::new();
+        if next_id != eos {
+            tokens.push(next_id);
+            generated.push(next_id);
+            let (key, _tok_buf) = models.load_input_mamba_decode(&mut handle, next_id, &stream);
+            x = models.launch_mamba(key, &mut mcache, &mut handle, &stream);
+        }
+
+        let max_decode_steps: usize = env::var("MAMBA_STEPS")
+            .ok()
+            .and_then(|s| s.parse().ok())
+            .unwrap_or(100);
+        for _step in 1..max_decode_steps {
+            let out_idx: [tokeneer::utok; 1] = [0];
+
+            let logits = output_head.launch(x.clone(), &out_idx, &mut handle, &stream);
+
+            let mut input = stream.malloc::<tokeneer::utok>(tokens.len());
+            stream.memcpy_h2d(&mut input, &tokens);
+            let cfg = vec![(
+                crate::batch::SessionId(0),
+                crate::batch::SampleInfo {
+                    args: SampleArgs::new(0.8, 0.95, 50, 1.2).unwrap(),
+                    input_idx: tokens.len(),
+                    decode_len: tokens.len(),
+                },
+            )];
+            let kv_pairs = sample_manager.sample(logits, &input, &cfg, &stream);
+            stream.free(input);
+            let mut host_kv = vec![KVPair::ZERO; 1];
+            stream.memcpy_d2h(&mut host_kv, &kv_pairs).free(kv_pairs);
+            next_id = host_kv[0].idx as tokeneer::utok;
+
+            if next_id == eos {
+                break;
+            }
+
+            tokens.push(next_id);
+            generated.push(next_id);
+            let (key, _tok_buf) = models.load_input_mamba_decode(&mut handle, next_id, &stream);
+            x = models.launch_mamba(key, &mut mcache, &mut handle, &stream);
+        }
+
+        println!("tokens = {:?}", tokens);
+        let mut text_buf = tokeneer::TextBuf::new();
+        let s = tokenizer.decode(&generated, &mut text_buf);
+        let buf = s.into_bytes();
+
+        let shape = <[usize; 2]>::try_from(x.shape().to_vec()).unwrap();
+        (buf, shape)
+    })
+}
+
+// #[cfg(test)]
+// mod tests {
+//     use super::*;
+//     use std::path::PathBuf;
+
+//     #[test]
+//     fn test_mamba_infer_decode() {
+//         let model = PathBuf::from("../model/Mamba_adf32-2.8B-hf-v1.0-F16.gguf");
+//         let prompt = "Once upon a time,";
+//         let (bytes, _shape) = mamba_infer(model, prompt, false);
+//         let text = String::from_utf8_lossy(&bytes);
+//         println!("prompt = {}", prompt);
+//         println!("mamba infer text = {}", text);
+//     }
+// }
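
A consequence of the design above: the MambaCache built from (n_layer, d_inner, d_conv, d_state) occupies the same memory no matter how long the generation runs. A back-of-the-envelope sketch, assuming f32 states (the actual layout is whatever exec/mamba_cache.rs defines):

// Rough per-sequence Mamba state size (sketch; assumes f32 elements).
fn mamba_state_bytes(n_layer: usize, d_inner: usize, d_conv: usize, d_state: usize) -> usize {
    let conv_state = d_inner * d_conv;  // rolling window for the causal conv1d
    let ssm_state = d_inner * d_state;  // recurrent selective-scan state
    n_layer * (conv_state + ssm_state) * std::mem::size_of::<f32>()
}

With the values hard-coded above (d_inner = 5120, d_conv = 4, d_state = 16) and, say, 64 layers, that is 64 × 5120 × 20 × 4 B ≈ 25 MiB, independent of context length.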
