Commit 9bc8178

committed: temp
1 parent 71c1e35 commit 9bc8178

File tree

9 files changed (+551 / -9 lines)


Cargo.lock

Lines changed: 0 additions & 6 deletions
Some generated files are not rendered by default.

InfiniLM

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Subproject commit 9eef45af70e2d65cf96373d79bdabb9b38003e45

llama.cu/Cargo.toml

Lines changed: 3 additions & 2 deletions
@@ -8,9 +8,10 @@ cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6a97931" }
 cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6a97931" }
 nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6a97931" }
 flash-attn = { git = "https://github.com/YdrMaster/learn-flash-attn", rev = "616bbac" }
-nn = { git = "https://github.com/YdrMaster/InfiniNN", rev = "fa8aaf6" }
+nn = { path = "/home/cearx/qy/InfiniNN/1_nn" }
 ggus = { git = "https://github.com/InfiniTensor/gguf", rev = "23c362f" }
-tokeneer = { git = "https://github.com/InfiniTensor/tokeneer", rev = "c48f39f" }
+#tokeneer = { git = "https://github.com/InfiniTensor/tokeneer", rev = "c48f39f" }
+tokeneer = { path = "/home/cearx/qy/tokeneer" }
 
 bytesize = "2.0"
 log.workspace = true

llama.cu/src/exec/group.rs

Lines changed: 201 additions & 1 deletion
@@ -1,8 +1,10 @@
+use super::mamba_cache::MambaCache;
 use super::{CacheParts, Progress, model::ModelExec, upos};
 use crate::{batch::Req, handle::Handle, load::load_weight, memory::MemPages};
 use cuda::{DevByte, DevMem, Stream, VirByte};
 use nn::{
-    Distribution, Graph, GraphBuilder, LLaMA, NNGraph, Tensor, TensorMeta, digit_layout::types, op,
+    Distribution, Graph, GraphBuilder, LLaMA, Mamba, NNGraph, Tensor, TensorMeta,
+    digit_layout::types, op,
 };
 use std::{
     collections::BTreeMap,
@@ -233,10 +235,208 @@ fn builder() -> GraphBuilder {
         .register_op("rope", op::rope::Rope)
         .register_op("attention", op::attention::Attention)
         .register_op("swiglu", op::activation::SwiGLU)
+        .register_op("silu", op::activation::SiLU)
         .register_op("concat", op::concat::Concat)
         .register_op("split", op::split::Split)
         .register_op("tile", op::tile::Tile)
         .register_op("merge", op::merge::Merge)
         .register_op("all-reduce", op::all_reduce::AllReduce);
     ans
 }
+
+// GraphBuilder for Mamba (registers the ops it needs)
+fn builder_mamba() -> GraphBuilder {
+    let mut ans = GraphBuilder::default();
+    ans.register_op("embedding", op::embedding::Embedding)
+        .register_op("rms-norm", op::normalization::RmsNorm)
+        .register_op("linear", op::linear::Linear)
+        .register_op("silu", op::activation::SiLU)
+        .register_op("element-mul", op::element_mul::ElementMul)
+        .register_op("split", op::split::Split)
+        .register_op("mamba-causal-conv1d", op::mamba::CausalConv1d)
+        .register_op("mamba-selective-scan", op::mamba::SelectiveScan);
+    ans
+}
+
+// Mamba inference group: token input only, no KV cache involved
+pub(crate) struct ModelGroupMamba<'ctx> {
+    internal: Internal<'ctx>,
+    pages: MemPages,
+    _weight: DevMem<'ctx>,
+    // Next position to write into pos_buf (for single-step incremental decode/prefill)
+    next_pos: u32,
+}
+
+impl<'ctx> ModelGroupMamba<'ctx> {
+    pub fn new<T: IntoIterator<Item = usize>>(
+        mamba: Mamba<Tensor<&[u8], 2>>,
+        dist: Distribution,
+        progress: Option<Arc<Progress>>, // reserved
+        config: ModelGroupConfig<T>,
+        handle: &mut Handle<'ctx>,
+        barrier: Option<&Barrier>,
+    ) -> Self {
+        let ModelGroupConfig {
+            static_model_keys,
+            mut dyn_cache_size,
+            use_cuda_graph,
+        } = config;
+
+        let NNGraph(Graph { topo, nodes, edges }) = builder_mamba()
+            .build(
+                mamba.tensor_parallel(dist),
+                [
+                    TensorMeta::new(types::U32, ["n_tok".into()]),
+                    TensorMeta::new(types::U32, ["n_tok".into()]),
+                    TensorMeta::new(types::U32, ["n_tok".into()]),
+                ],
+            )
+            .unwrap();
+        handle.ctx.stream().synchronize();
+
+        let dev = handle.ctx.dev();
+        let mut pages = MemPages::new(dev);
+        let (_weight, edges) = load_weight(edges, progress, handle.ctx);
+        let graph = NNGraph(Graph { topo, nodes, edges });
+        let static_models = if use_cuda_graph {
+            static_model_keys
+                .into_iter()
+                .map(|n_tok| {
+                    if let Some(b) = barrier {
+                        b.wait();
+                    }
+                    let key = NonZeroUsize::new(n_tok).unwrap();
+                    let exec = ModelExec::new(graph.clone(), n_tok, handle, &mut pages, true);
+                    (key, exec)
+                })
+                .collect::<BTreeMap<_, _>>()
+        } else {
+            dyn_cache_size += static_model_keys.into_iter().count();
+            Default::default()
+        };
+
+        let internal = Internal::new(graph, static_models, dyn_cache_size);
+        Self {
+            internal,
+            pages,
+            _weight,
+            next_pos: 0,
+        }
+    }
+
+    pub fn load_inputs_mamba(
+        &mut self,
+        handle: &mut Handle<'ctx>,
+        len: usize,
+        tok: &[utok],
+        stream: &Stream<'ctx>,
+    ) -> (NonZeroUsize, &mut [DevByte]) {
+        let key = self.internal.get_key(NonZeroUsize::new(len).unwrap());
+        let model = self.internal.map_exec(key, handle, &mut self.pages, stream);
+        stream.memcpy_h2d(model.tok_buf(), &tok[..key.get()]);
+        let pos: Vec<upos> = (0..key.get()).map(|i| i as upos).collect();
+        stream.memcpy_h2d(model.pos_buf(), &pos);
+        // Align next_pos with the end of prefill so subsequent decode steps keep incrementing from it
+        self.next_pos = key.get() as u32;
+        // out_idx: during prefill, compute the output head at every position
+        let out_idx: Vec<utok> = (0..key.get()).map(|i| i as utok).collect();
+        let buf = model.input_buf_at(2);
+        stream.memcpy_h2d(buf, &out_idx);
+        (key, model.tok_buf())
+    }
+
+    #[cfg(nccl)]
+    pub fn share_inputs(
+        &mut self,
+        key: NonZeroUsize,
+        handle: &mut Handle<'ctx>,
+        stream: &Stream<'ctx>,
+    ) {
+        let model = self.internal.map_exec(key, handle, &mut self.pages, stream);
+        if let Some(comm) = &handle.comm {
+            comm.broadcast(model.tok_buf(), None, 0, stream);
+        }
+    }
+
+    pub fn launch(
+        &mut self,
+        key: NonZeroUsize,
+        handle: &mut Handle,
+        stream: &Stream<'ctx>,
+    ) -> Tensor<*const VirByte, 2> {
+        self.internal
+            .get_mut(&key)
+            .unwrap()
+            .launch(handle, &[], stream)
+    }
+
+    /// Single incremental step: load one token (pos increments from next_pos, out_idx fixed to 0)
+    pub fn append_input_mamba(
+        &mut self,
+        handle: &mut Handle<'ctx>,
+        tok: utok,
+        stream: &Stream<'ctx>,
+    ) -> (NonZeroUsize, &mut [DevByte]) {
+        // Use the n_tok = 1 model
+        let key = self.internal.get_key(NonZeroUsize::new(1).unwrap());
+        let model = self.internal.map_exec(key, handle, &mut self.pages, stream);
+        // tok
+        let tok_buf = model.tok_buf();
+        stream.memcpy_h2d(tok_buf, &[tok]);
+        // pos increments (prefill starts at 0, decode continues from the end of prefill)
+        let pos_buf = model.pos_buf();
+        let cur = self.next_pos;
+        stream.memcpy_h2d(pos_buf, &[cur]);
+        self.next_pos = cur.saturating_add(1);
+        // out_idx fixed to 0
+        let out_idx_buf = model.input_buf_at(2);
+        stream.memcpy_h2d(out_idx_buf, &[0u32]);
+        (key, model.tok_buf())
+    }
+
+    /// Set the start position of the next write (for explicitly aligning prefill → decode)
+    pub fn set_decode_start_pos(&mut self, start: u32) {
+        self.next_pos = start;
+    }
+
+    /// Single incremental step: run one step and return the hidden state (one position only)
+    pub fn launch_step(
+        &mut self,
+        key: NonZeroUsize,
+        handle: &mut Handle,
+        stream: &Stream<'ctx>,
+    ) -> Tensor<*const VirByte, 2> {
+        // For now this reuses the existing graph execution path (n_tok = 1); a Step::Mamba kernel that updates the state in place will be wired in later
+        self.internal
+            .get_mut(&key)
+            .unwrap()
+            .launch(handle, &[], stream)
+    }
+
+    /// Single incremental step (reserved): variant that carries a mamba cache
+    pub fn launch_step_with_cache(
+        &mut self,
+        key: NonZeroUsize,
+        cache: &mut MambaCache,
+        cache_pos: usize,
+        handle: &mut Handle,
+        stream: &Stream<'ctx>,
+    ) -> Tensor<*const VirByte, 2> {
+        // Switch to the dedicated single-step path: only the Mamba Step execution is replaced, everything else still follows the original graph
+        let model = self.internal.get_mut(&key).unwrap();
+        model.launch_with_mamba_cache(handle, cache, cache_pos, stream)
+    }
+
+    /// Prefill phase: variant that carries a mamba cache, runs prefill and writes the state back
+    pub fn launch_prefill_with_cache(
+        &mut self,
+        key: NonZeroUsize,
+        cache: &mut MambaCache,
+        handle: &mut Handle,
+        stream: &Stream<'ctx>,
+    ) -> Tensor<*const VirByte, 2> {
+        // Switch to the dedicated prefill path: only the Mamba Step execution is replaced, everything else still follows the original graph
+        let model = self.internal.get_mut(&key).unwrap();
+        model.launch_with_mamba_cache(handle, cache, 0, stream)
+    }
+}
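
For orientation, a minimal sketch of how the ModelGroupMamba API above is meant to be driven: prefill once over the prompt, then decode one token at a time. The driver function run_mamba, its signature, the assumption that it lives inside the crate (ModelGroupMamba is pub(crate)), the cache_pos convention (taken here to be the running token position), and the omitted sampling step are illustrative assumptions, not part of this commit.

fn run_mamba<'ctx>(
    group: &mut ModelGroupMamba<'ctx>,
    cache: &mut MambaCache,
    handle: &mut Handle<'ctx>,
    stream: &Stream<'ctx>,
    prompt: &[utok],
    steps: usize,
) {
    // Prefill: upload the whole prompt; load_inputs_mamba also aligns next_pos to prompt.len().
    let (key, _) = group.load_inputs_mamba(handle, prompt.len(), prompt, stream);
    // Run prefill through the cache-carrying path so the Mamba state is written back.
    let _hidden = group.launch_prefill_with_cache(key, cache, handle, stream);

    // Decode: one token per step; pos keeps incrementing from the end of the prompt.
    let next_tok: utok = 0; // assumption: a real driver would re-sample this from the hidden state each step
    for i in 0..steps {
        let (key, _) = group.append_input_mamba(handle, next_tok, stream);
        // cache_pos is assumed here to be the running token position.
        let _hidden = group.launch_step_with_cache(key, cache, prompt.len() + i, handle, stream);
    }
}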
