|
| 1 | +use super::mamba_cache::MambaCache; |
1 | 2 | use super::{CacheParts, Progress, model::ModelExec, upos}; |
2 | 3 | use crate::{batch::Req, handle::Handle, load::load_weight, memory::MemPages}; |
3 | 4 | use cuda::{DevByte, DevMem, Stream, VirByte}; |
4 | 5 | use nn::{ |
5 | | - Distribution, Graph, GraphBuilder, LLaMA, NNGraph, Tensor, TensorMeta, digit_layout::types, op, |
| 6 | + Distribution, Graph, GraphBuilder, LLaMA, Mamba, NNGraph, Tensor, TensorMeta, |
| 7 | + digit_layout::types, op, |
6 | 8 | }; |
7 | 9 | use std::{ |
8 | 10 | collections::BTreeMap, |
@@ -233,10 +235,208 @@ fn builder() -> GraphBuilder { |
233 | 235 | .register_op("rope", op::rope::Rope) |
234 | 236 | .register_op("attention", op::attention::Attention) |
235 | 237 | .register_op("swiglu", op::activation::SwiGLU) |
| 238 | + .register_op("silu", op::activation::SiLU) |
236 | 239 | .register_op("concat", op::concat::Concat) |
237 | 240 | .register_op("split", op::split::Split) |
238 | 241 | .register_op("tile", op::tile::Tile) |
239 | 242 | .register_op("merge", op::merge::Merge) |
240 | 243 | .register_op("all-reduce", op::all_reduce::AllReduce); |
241 | 244 | ans |
242 | 245 | } |
| 246 | + |
| 247 | +// 针对 Mamba 的 GraphBuilder(注册其所需算子) |
| 248 | +fn builder_mamba() -> GraphBuilder { |
| 249 | + let mut ans = GraphBuilder::default(); |
| 250 | + ans.register_op("embedding", op::embedding::Embedding) |
| 251 | + .register_op("rms-norm", op::normalization::RmsNorm) |
| 252 | + .register_op("linear", op::linear::Linear) |
| 253 | + .register_op("silu", op::activation::SiLU) |
| 254 | + .register_op("element-mul", op::element_mul::ElementMul) |
| 255 | + .register_op("split", op::split::Split) |
| 256 | + .register_op("mamba-causal-conv1d", op::mamba::CausalConv1d) |
| 257 | + .register_op("mamba-selective-scan", op::mamba::SelectiveScan); |
| 258 | + ans |
| 259 | +} |
| 260 | + |
// Mamba inference group: takes token inputs only; no KV cache is involved
// (Mamba carries its recurrent state separately — see `MambaCache`).
pub(crate) struct ModelGroupMamba<'ctx> {
    // Compiled graph plus the per-`n_tok` executor cache (static CUDA-graph
    // entries and dynamically instantiated ones).
    internal: Internal<'ctx>,
    // Device memory pages backing the executors' buffers.
    pages: MemPages,
    // Keeps the loaded weights alive for the lifetime of the group.
    _weight: DevMem<'ctx>,
    // Position of the next write into pos_buf (drives single-step
    // incremental decode/prefill).
    next_pos: u32,
}
| 269 | + |
/// Driving methods for a Mamba model group: graph construction, input
/// staging (prefill and single-step decode), and launch variants.
impl<'ctx> ModelGroupMamba<'ctx> {
    /// Builds the group from model weights.
    ///
    /// Constructs the Mamba compute graph (three `U32["n_tok"]` inputs —
    /// used below as tok / pos / out_idx), loads weights to device memory,
    /// and, when `use_cuda_graph` is set, pre-instantiates one executor per
    /// static `n_tok` key, waiting on `barrier` (if any) before each so
    /// multi-rank setups stay in lockstep. Otherwise the static keys are
    /// folded into the dynamic cache budget.
    pub fn new<T: IntoIterator<Item = usize>>(
        mamba: Mamba<Tensor<&[u8], 2>>,
        dist: Distribution,
        progress: Option<Arc<Progress>>, // reserved for load-progress reporting
        config: ModelGroupConfig<T>,
        handle: &mut Handle<'ctx>,
        barrier: Option<&Barrier>,
    ) -> Self {
        let ModelGroupConfig {
            static_model_keys,
            mut dyn_cache_size,
            use_cuda_graph,
        } = config;

        let NNGraph(Graph { topo, nodes, edges }) = builder_mamba()
            .build(
                mamba.tensor_parallel(dist),
                [
                    TensorMeta::new(types::U32, ["n_tok".into()]),
                    TensorMeta::new(types::U32, ["n_tok".into()]),
                    TensorMeta::new(types::U32, ["n_tok".into()]),
                ],
            )
            .unwrap();
        // Ensure any device work issued during graph building has completed
        // before weights are loaded.
        handle.ctx.stream().synchronize();

        let dev = handle.ctx.dev();
        let mut pages = MemPages::new(dev);
        let (_weight, edges) = load_weight(edges, progress, handle.ctx);
        let graph = NNGraph(Graph { topo, nodes, edges });
        let static_models = if use_cuda_graph {
            static_model_keys
                .into_iter()
                .map(|n_tok| {
                    if let Some(b) = barrier {
                        b.wait();
                    }
                    let key = NonZeroUsize::new(n_tok).unwrap();
                    let exec = ModelExec::new(graph.clone(), n_tok, handle, &mut pages, true);
                    (key, exec)
                })
                .collect::<BTreeMap<_, _>>()
        } else {
            // No CUDA graphs: let the dynamic cache also cover the would-be
            // static sizes.
            dyn_cache_size += static_model_keys.into_iter().count();
            Default::default()
        };

        let internal = Internal::new(graph, static_models, dyn_cache_size);
        Self {
            internal,
            pages,
            _weight,
            next_pos: 0,
        }
    }

    /// Stages prefill inputs: copies `len` tokens, positions `0..len`, and
    /// an out_idx covering every position into the executor's input buffers.
    /// Returns the chosen model key and the token buffer.
    ///
    /// NOTE(review): `tok` is sliced to `key.get()` elements — assumes the
    /// key chosen by `get_key` never exceeds `tok.len()`; confirm against
    /// `Internal::get_key`'s rounding behavior.
    pub fn load_inputs_mamba(
        &mut self,
        handle: &mut Handle<'ctx>,
        len: usize,
        tok: &[utok],
        stream: &Stream<'ctx>,
    ) -> (NonZeroUsize, &mut [DevByte]) {
        let key = self.internal.get_key(NonZeroUsize::new(len).unwrap());
        let model = self.internal.map_exec(key, handle, &mut self.pages, stream);
        stream.memcpy_h2d(model.tok_buf(), &tok[..key.get()]);
        let pos: Vec<upos> = (0..key.get()).map(|i| i as upos).collect();
        stream.memcpy_h2d(model.pos_buf(), &pos);
        // Align next_pos to the end of the prefill so subsequent decode
        // steps continue incrementing from here.
        self.next_pos = key.get() as u32;
        // out_idx: during prefill, compute the output head for every position.
        let out_idx: Vec<utok> = (0..key.get()).map(|i| i as utok).collect();
        let buf = model.input_buf_at(2);
        stream.memcpy_h2d(buf, &out_idx);
        (key, model.tok_buf())
    }

    /// NCCL builds only: broadcasts the token buffer from rank 0 to all
    /// ranks so every device executes on the same input.
    #[cfg(nccl)]
    pub fn share_inputs(
        &mut self,
        key: NonZeroUsize,
        handle: &mut Handle<'ctx>,
        stream: &Stream<'ctx>,
    ) {
        let model = self.internal.map_exec(key, handle, &mut self.pages, stream);
        if let Some(comm) = &handle.comm {
            comm.broadcast(model.tok_buf(), None, 0, stream);
        }
    }

    /// Launches full-graph execution for the executor at `key`, returning
    /// the output tensor.
    pub fn launch(
        &mut self,
        key: NonZeroUsize,
        handle: &mut Handle,
        stream: &Stream<'ctx>,
    ) -> Tensor<*const VirByte, 2> {
        self.internal
            .get_mut(&key)
            .unwrap()
            .launch(handle, &[], stream)
    }

    /// Single-step incremental: stages one token. The position comes from
    /// `next_pos` (auto-incremented); out_idx is fixed to 0.
    pub fn append_input_mamba(
        &mut self,
        handle: &mut Handle<'ctx>,
        tok: utok,
        stream: &Stream<'ctx>,
    ) -> (NonZeroUsize, &mut [DevByte]) {
        // Use the n_tok = 1 model.
        let key = self.internal.get_key(NonZeroUsize::new(1).unwrap());
        let model = self.internal.map_exec(key, handle, &mut self.pages, stream);
        // tok
        let tok_buf = model.tok_buf();
        stream.memcpy_h2d(tok_buf, &[tok]);
        // pos increments (prefill starts at 0; decode continues from the
        // end of the prefill).
        let pos_buf = model.pos_buf();
        let cur = self.next_pos;
        stream.memcpy_h2d(pos_buf, &[cur]);
        self.next_pos = cur.saturating_add(1);
        // out_idx fixed to 0 (only one position produces output).
        let out_idx_buf = model.input_buf_at(2);
        stream.memcpy_h2d(out_idx_buf, &[0u32]);
        (key, model.tok_buf())
    }

    /// Explicitly sets the position of the next staged token (used to align
    /// the prefill → decode transition).
    pub fn set_decode_start_pos(&mut self, start: u32) {
        self.next_pos = start;
    }

    /// Single-step incremental: executes one step and returns the hidden
    /// state (a single position).
    pub fn launch_step(
        &mut self,
        key: NonZeroUsize,
        handle: &mut Handle,
        stream: &Stream<'ctx>,
    ) -> Tensor<*const VirByte, 2> {
        // Currently reuses the regular graph execution path (n_tok = 1);
        // a dedicated Step::Mamba kernel updating state in place is planned.
        self.internal
            .get_mut(&key)
            .unwrap()
            .launch(handle, &[], stream)
    }

    /// Single-step incremental (reserved): variant that carries an explicit
    /// Mamba cache, resuming from `cache_pos`.
    pub fn launch_step_with_cache(
        &mut self,
        key: NonZeroUsize,
        cache: &mut MambaCache,
        cache_pos: usize,
        handle: &mut Handle,
        stream: &Stream<'ctx>,
    ) -> Tensor<*const VirByte, 2> {
        // Dedicated single-step path: only the Mamba step execution is
        // swapped out; everything else still follows the original graph.
        let model = self.internal.get_mut(&key).unwrap();
        model.launch_with_mamba_cache(handle, cache, cache_pos, stream)
    }

    /// Prefill variant with a Mamba cache: runs the prefill and writes the
    /// resulting state back into `cache` (cache position starts at 0).
    pub fn launch_prefill_with_cache(
        &mut self,
        key: NonZeroUsize,
        cache: &mut MambaCache,
        handle: &mut Handle,
        stream: &Stream<'ctx>,
    ) -> Tensor<*const VirByte, 2> {
        // Dedicated prefill path: only the Mamba step execution is swapped
        // out; everything else still follows the original graph.
        let model = self.internal.get_mut(&key).unwrap();
        model.launch_with_mamba_cache(handle, cache, 0, stream)
    }
}
0 commit comments