diff --git a/.gitignore b/.gitignore index 882d6465..4e303183 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /target +/*.txt diff --git a/Cargo.lock b/Cargo.lock index eaf667f3..e8e51c2e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -100,7 +100,7 @@ dependencies = [ [[package]] name = "arg" version = "0.0.0" -source = "git+https://github.com/YdrMaster/InfiniNN?rev=fa8aaf6#fa8aaf6361947ced9d3e624a72336ccfe67aae79" +source = "git+https://github.com/CearX/InfiniNN?rev=6caef2#6caef2f2f4365932bb2c32e9dc1b99e9521c3dbf" dependencies = [ "symbolic-expr", ] @@ -613,7 +613,7 @@ dependencies = [ [[package]] name = "exec" version = "0.0.0" -source = "git+https://github.com/YdrMaster/InfiniNN?rev=fa8aaf6#fa8aaf6361947ced9d3e624a72336ccfe67aae79" +source = "git+https://github.com/CearX/InfiniNN?rev=6caef2#6caef2f2f4365932bb2c32e9dc1b99e9521c3dbf" dependencies = [ "arg", "graph", @@ -830,7 +830,7 @@ checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" [[package]] name = "graph" version = "0.0.0" -source = "git+https://github.com/YdrMaster/InfiniNN?rev=fa8aaf6#fa8aaf6361947ced9d3e624a72336ccfe67aae79" +source = "git+https://github.com/CearX/InfiniNN?rev=6caef2#6caef2f2f4365932bb2c32e9dc1b99e9521c3dbf" [[package]] name = "h2" @@ -1307,6 +1307,7 @@ dependencies = [ "cuda-cc", "flash-attn", "ggus", + "half", "log", "lru 0.14.0", "memmap2", @@ -1359,7 +1360,7 @@ dependencies = [ [[package]] name = "mem" version = "0.0.0" -source = "git+https://github.com/YdrMaster/InfiniNN?rev=fa8aaf6#fa8aaf6361947ced9d3e624a72336ccfe67aae79" +source = "git+https://github.com/CearX/InfiniNN?rev=6caef2#6caef2f2f4365932bb2c32e9dc1b99e9521c3dbf" dependencies = [ "arg", "exec", @@ -1483,7 +1484,7 @@ checksum = "86a1db1a5ed8293057d401ebd96872cb881a3693d9b55379ae320f652aea3714" [[package]] name = "nn" version = "0.0.0" -source = "git+https://github.com/YdrMaster/InfiniNN?rev=fa8aaf6#fa8aaf6361947ced9d3e624a72336ccfe67aae79" +source = "git+https://github.com/CearX/InfiniNN?rev=6caef2#6caef2f2f4365932bb2c32e9dc1b99e9521c3dbf" dependencies = [ "arg", "graph", @@ -2384,7 +2385,7 @@ dependencies = [ [[package]] name = "tokeneer" version = "0.1.0" -source = "git+https://github.com/InfiniTensor/tokeneer?rev=c48f39f#c48f39f18bbe36a7f727f2f6f06db9c4eccc351a" +source = "git+https://github.com/CearX/tokeneer.git?rev=2546d72#2546d72c0cb12bc15cd40a6ec06f7f9476010155" dependencies = [ "fancy-regex", "ggus", diff --git a/llama.cu/Cargo.toml b/llama.cu/Cargo.toml index 48de65b9..0024d4cb 100644 --- a/llama.cu/Cargo.toml +++ b/llama.cu/Cargo.toml @@ -8,9 +8,9 @@ cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6a97931" } cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6a97931" } nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6a97931" } flash-attn = { git = "https://github.com/YdrMaster/learn-flash-attn", rev = "616bbac" } -nn = { git = "https://github.com/YdrMaster/InfiniNN", rev = "fa8aaf6" } +nn = { git = "https://github.com/CearX/InfiniNN", rev = "6caef2"} ggus = { git = "https://github.com/InfiniTensor/gguf", rev = "23c362f" } -tokeneer = { git = "https://github.com/InfiniTensor/tokeneer", rev = "c48f39f" } +tokeneer = { git = "https://github.com/CearX/tokeneer.git", rev = "2546d72" } bytesize = "2.0" log.workspace = true @@ -19,6 +19,7 @@ serde.workspace = true memmap2 = "0.9" lru = "0.14" rand = "0.9" +half = "2.3" minijinja = { version = "2.11", default-features = false, features = [ "loader", "builtins", diff --git a/llama.cu/src/exec/group.rs 
b/llama.cu/src/exec/group.rs index 58783239..7fc040a1 100644 --- a/llama.cu/src/exec/group.rs +++ b/llama.cu/src/exec/group.rs @@ -1,8 +1,10 @@ +use super::mamba_cache::MambaCache; use super::{CacheParts, Progress, model::ModelExec, upos}; use crate::{batch::Req, handle::Handle, load::load_weight, memory::MemPages}; use cuda::{DevByte, DevMem, Stream, VirByte}; use nn::{ - Distribution, Graph, GraphBuilder, LLaMA, NNGraph, Tensor, TensorMeta, digit_layout::types, op, + Distribution, Graph, GraphBuilder, LLaMA, Mamba, NNGraph, Tensor, TensorMeta, + digit_layout::types, op, }; use std::{ collections::BTreeMap, @@ -240,3 +242,133 @@ fn builder() -> GraphBuilder { .register_op("all-reduce", op::all_reduce::AllReduce); ans } + +// Mamba GraphBuilder +fn builder_mamba() -> GraphBuilder { + let mut ans = GraphBuilder::default(); + ans.register_op("embedding", op::embedding::Embedding) + .register_op("rms-norm", op::normalization::RmsNorm) + .register_op("linear", op::linear::Linear) + .register_op("silu", op::activation::SiLU) + .register_op("element-mul", op::element_mul::ElementMul) + .register_op("split", op::split::Split) + .register_op("mamba-causal-conv1d", op::mamba::CausalConv1d) + .register_op("mamba-selective-scan", op::mamba::SelectiveScan); + ans +} + +pub(crate) struct ModelGroupMamba<'ctx> { + internal: Internal<'ctx>, + pages: MemPages, + _weight: DevMem<'ctx>, + next_pos: u32, +} + +impl<'ctx> ModelGroupMamba<'ctx> { + pub fn new>( + mamba: Mamba>, + dist: Distribution, + progress: Option>, + config: ModelGroupConfig, + handle: &mut Handle<'ctx>, + barrier: Option<&Barrier>, + ) -> Self { + let ModelGroupConfig { + static_model_keys, + mut dyn_cache_size, + use_cuda_graph, + } = config; + + let NNGraph(Graph { topo, nodes, edges }) = builder_mamba() + .build( + mamba.tensor_parallel(dist), + [ + TensorMeta::new(types::U32, ["n_tok".into()]), + TensorMeta::new(types::U32, ["n_tok".into()]), + TensorMeta::new(types::U32, ["n_tok".into()]), + ], + ) + .unwrap(); + handle.ctx.stream().synchronize(); + + let dev = handle.ctx.dev(); + let mut pages = MemPages::new(dev); + let (_weight, edges) = load_weight(edges, progress, handle.ctx); + let graph = NNGraph(Graph { topo, nodes, edges }); + let static_models = if use_cuda_graph { + static_model_keys + .into_iter() + .map(|n_tok| { + if let Some(b) = barrier { + b.wait(); + } + let key = NonZeroUsize::new(n_tok).unwrap(); + let exec = ModelExec::new(graph.clone(), n_tok, handle, &mut pages, true); + (key, exec) + }) + .collect::>() + } else { + dyn_cache_size += static_model_keys.into_iter().count(); + Default::default() + }; + + let internal = Internal::new(graph, static_models, dyn_cache_size); + Self { + internal, + pages, + _weight, + next_pos: 0, + } + } + + pub fn load_inputs_mamba_prefill( + &mut self, + handle: &mut Handle<'ctx>, + len: usize, + tok: &[utok], + stream: &Stream<'ctx>, + ) -> (NonZeroUsize, &mut [DevByte]) { + let key = self.internal.get_key(NonZeroUsize::new(len).unwrap()); + let model = self.internal.map_exec(key, handle, &mut self.pages, stream); + stream.memcpy_h2d(model.tok_buf(), &tok[..key.get()]); + let pos: Vec = (0..key.get()).map(|i| i as upos).collect(); + stream.memcpy_h2d(model.pos_buf(), &pos); + self.next_pos = key.get() as u32; + let out_idx: Vec = (0..key.get()).map(|i| i as utok).collect(); + let buf = model.input_buf_at(2); + stream.memcpy_h2d(buf, &out_idx); + (key, model.tok_buf()) + } + + pub fn load_input_mamba_decode( + &mut self, + handle: &mut Handle<'ctx>, + tok: utok, + stream: 
&Stream<'ctx>,
+    ) -> (NonZeroUsize, &mut [DevByte]) {
+        let key = self.internal.get_key(NonZeroUsize::new(1).unwrap());
+        let model = self.internal.map_exec(key, handle, &mut self.pages, stream);
+        let tok_buf = model.tok_buf();
+        stream.memcpy_h2d(tok_buf, &[tok]);
+        let pos_buf = model.pos_buf();
+        let cur = self.next_pos;
+        stream.memcpy_h2d(pos_buf, &[cur]);
+        // advance next_pos
+        self.next_pos = cur.saturating_add(1);
+        // during decode, out_idx is fixed at 0
+        let out_idx_buf = model.input_buf_at(2);
+        stream.memcpy_h2d(out_idx_buf, &[0u32]);
+        (key, model.tok_buf())
+    }
+
+    pub fn launch_mamba(
+        &mut self,
+        key: NonZeroUsize,
+        cache: &mut MambaCache,
+        handle: &mut Handle,
+        stream: &Stream<'ctx>,
+    ) -> Tensor<*const VirByte, 2> {
+        let model = self.internal.get_mut(&key).unwrap();
+        model.launch_with_mamba_cache(handle, cache, stream)
+    }
+}
diff --git a/llama.cu/src/exec/mamba.rs b/llama.cu/src/exec/mamba.rs
new file mode 100644
index 00000000..c8d3d27b
--- /dev/null
+++ b/llama.cu/src/exec/mamba.rs
@@ -0,0 +1,181 @@
+use crate::exec::group::{ModelGroupConfig, ModelGroupMamba};
+use crate::exec::mamba_cache::MambaCache;
+use crate::exec::output_head::OutputHead;
+use crate::exec::sample_manager::SampleManager;
+use crate::memory::MemPages;
+use crate::op::random_sample::{KVPair, SampleArgs};
+use crate::utils::{self, meta};
+use crate::{handle::Handle, model::map_files};
+use cuda::Device;
+use ggus::GGufMetaMapExt;
+use nn::Distribution;
+use std::env;
+use tokeneer::Bpe;
+
+#[allow(dead_code)]
+pub fn mamba_infer(
+    model_path: std::path::PathBuf,
+    text: &str,
+    use_cuda_graph: bool,
+) -> (String, usize) {
+    use crate::model::GGufModel;
+    // initialize CUDA
+    assert!(cuda::init().is_ok());
+
+    // load the model
+    let maps = map_files(model_path);
+    let gguf = GGufModel::read(maps.iter().map(|x| &**x));
+    let tokenizer = Bpe::from_gguf(&gguf);
+    let mut tokens = tokenizer.encode(text);
+
+    let n_tok = tokens.len();
+
+    // take the output head out for logits computation
+    let mut mamba = gguf.mamba();
+    let output_head_nn = mamba
+        .output_head
+        .take()
+        .expect("mamba model missing output_head");
+
+    let n_layer: usize = meta![gguf => llm_block_count];
+    let d_inner: usize = 5120; // TODO: ggus
+    let d_conv: usize = 4; // kernel size
+    let d_state: usize = 16; // ssm state size
+
+    // single GPU
+    let device = Device::new(0);
+    device.retain_primary().apply(|ctx| {
+        let mut handle = Handle::new(ctx);
+        let dist = Distribution {
+            start: 0,
+            len: 1,
+            total: 1,
+        };
+        let mut models = ModelGroupMamba::new(
+            mamba,
+            dist,
+            None,
+            ModelGroupConfig {
+                static_model_keys: [n_tok],
+                dyn_cache_size: 1,
+                use_cuda_graph,
+            },
+            &mut handle,
+            None,
+        );
+
+        // components: output head and sampler
+        let stream = ctx.stream();
+        let mut output_head = OutputHead::new(output_head_nn, ctx);
+
+        // read the vocab size to build the sampler, and fetch eos
+        let eos: tokeneer::utok = meta![gguf => tokenizer_ggml_eos_token_id];
+        let nvoc = output_head.nvoc();
+        let mut sample_manager = SampleManager::new(nvoc, eos, ctx);
+
+        // initialize the MambaCache
+        let mut pages = MemPages::new(device);
+        let mut mcache = MambaCache::new(n_layer, d_inner, d_conv, d_state, &mut pages);
+
+        // Prefill
+        let (key, _tok_buf) =
+            models.load_inputs_mamba_prefill(&mut handle, tokens.len(), &tokens, &stream);
+
+        let mut x = models.launch_mamba(key, &mut mcache, &mut handle, &stream);
+
+        let last_idx: [tokeneer::utok; 1] = [(tokens.len() - 1) as tokeneer::utok];
+        let logits_prefill_last = output_head.launch(x.clone(), &last_idx, &mut handle, &stream);
+
+        let logits_prefill_last_vir = logits_prefill_last.as_ref().map(|mem|
mem.as_ptr().cast()); + utils::fmt(&logits_prefill_last_vir, stream.ctx()); + // check prefill logits + + let mut next_id: tokeneer::utok; + { + let mut input = stream.malloc::(tokens.len()); + stream.memcpy_h2d(&mut input, &tokens); + let cfg0 = vec![( + crate::batch::SessionId(0), + crate::batch::SampleInfo { + args: SampleArgs::new(0.8, 0.95, 50, 1.2).unwrap(), + input_idx: tokens.len(), + decode_len: tokens.len(), + }, + )]; + let kv_pairs0 = sample_manager.sample(logits_prefill_last, &input, &cfg0, &stream); + stream.free(input); + let mut host_kv0 = vec![KVPair::ZERO; 1]; + stream.memcpy_d2h(&mut host_kv0, &kv_pairs0).free(kv_pairs0); + next_id = host_kv0[0].idx as tokeneer::utok; + } + let mut generated: Vec = Vec::new(); + if next_id != eos { + tokens.push(next_id); + generated.push(next_id); + let (key, _tok_buf) = models.load_input_mamba_decode(&mut handle, next_id, &stream); + x = models.launch_mamba(key, &mut mcache, &mut handle, &stream); + } + + let max_decode_steps: usize = env::var("MAMBA_STEPS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(100); + for _step in 1..max_decode_steps { + let out_idx: [tokeneer::utok; 1] = [0]; + + let logits = output_head.launch(x.clone(), &out_idx, &mut handle, &stream); + + let mut input = stream.malloc::(tokens.len()); + stream.memcpy_h2d(&mut input, &tokens); + let cfg = vec![( + crate::batch::SessionId(0), + crate::batch::SampleInfo { + args: SampleArgs::new(0.8, 0.95, 50, 1.2).unwrap(), + input_idx: tokens.len(), + decode_len: tokens.len(), + }, + )]; + let kv_pairs = sample_manager.sample(logits, &input, &cfg, &stream); + stream.free(input); + let mut host_kv = vec![KVPair::ZERO; 1]; + stream.memcpy_d2h(&mut host_kv, &kv_pairs).free(kv_pairs); + next_id = host_kv[0].idx as tokeneer::utok; + + if next_id == eos { + break; + } + + tokens.push(next_id); + generated.push(next_id); + let (key, _tok_buf) = models.load_input_mamba_decode(&mut handle, next_id, &stream); + x = models.launch_mamba(key, &mut mcache, &mut handle, &stream); + } + + println!("tokens = {:?}", tokens); + let mut text_buf = tokeneer::TextBuf::new(); + let s = tokenizer.decode(&generated, &mut text_buf); + let text = String::from_utf8_lossy(&s.into_bytes()).to_string(); + + (text, tokens.len()) + }) +} + +// #[cfg(test)] +// mod tests { +// use super::*; +// use std::{path::PathBuf, time::Instant}; + +// #[test] +// fn test_mamba_infer_decode() { +// let start = Instant::now(); +// let model = PathBuf::from("/home/cearx/qy/model/Mamba_adf32-2.8B-hf-v1.0-F16.gguf"); +// let prompt = "Once upon a time,"; +// let (text, len) = mamba_infer(model, prompt, false); +// let end = Instant::now(); +// let tokens_per_second = len as f64 / (end - start).as_secs_f64(); +// println!("infer time = {:?}", end - start); +// println!("tokens/s = {}", tokens_per_second); +// println!("prompt = {}", prompt); +// println!("mamba infer text = {}", text); +// } +// } diff --git a/llama.cu/src/exec/mamba_cache.rs b/llama.cu/src/exec/mamba_cache.rs new file mode 100644 index 00000000..5190662f --- /dev/null +++ b/llama.cu/src/exec/mamba_cache.rs @@ -0,0 +1,108 @@ +use crate::memory::MemPages; +use cuda::{VirByte, VirMem}; +use nn::{Tensor, digit_layout::types}; + +/// 每层的 Mamba 状态 +pub struct MambaLayerState { + conv: Tensor, // [d_inner, d_conv] + ssm: Tensor, // [d_inner, d_state] + conv_mapped: usize, + ssm_mapped: usize, +} + +pub struct MambaCache { + pub layers: Box<[MambaLayerState]>, + pub conv_size_per_layer: usize, + pub ssm_size_per_layer: usize, +} + +impl MambaCache { + pub 
fn new(
+        n_layers: usize,
+        d_inner: usize,
+        d_conv: usize,
+        d_state: usize,
+        pages: &mut MemPages,
+    ) -> Self {
+        let conv_size_per_layer = d_inner * d_conv * 4; // F32
+        let ssm_size_per_layer = d_inner * d_state * 4; // F32
+
+        let layers = (0..n_layers)
+            .map(|_| {
+                let conv_tensor = Tensor::from_dim_slice(types::F32, [d_inner, d_conv])
+                    .map(|len| pages.reserve_vir(len));
+
+                let ssm_tensor = Tensor::from_dim_slice(types::F32, [d_inner, d_state])
+                    .map(|len| pages.reserve_vir(len));
+
+                MambaLayerState {
+                    conv: conv_tensor,
+                    ssm: ssm_tensor,
+                    conv_mapped: 0,
+                    ssm_mapped: 0,
+                }
+            })
+            .collect();
+
+        let mut cache = Self {
+            layers,
+            conv_size_per_layer,
+            ssm_size_per_layer,
+        };
+
+        // map physical pages for every layer immediately
+        for layer_idx in 0..n_layers {
+            cache.ensure_mapped(layer_idx, pages);
+        }
+
+        cache
+    }
+
+    /// Update the physical-page mapping of the conv cache
+    pub fn update_conv_mapping(&mut self, layer_idx: usize, pages: &mut MemPages) {
+        let layer = &mut self.layers[layer_idx];
+        let page_size = pages.page_size();
+        let target = self.conv_size_per_layer.div_ceil(page_size);
+
+        let mem = layer.conv.get_mut();
+        use std::cmp::Ordering::{Equal, Greater, Less};
+        match layer.conv_mapped.cmp(&target) {
+            Less => pages.map(mem, layer.conv_mapped..target),
+            Greater => pages.unmap(mem, target..layer.conv_mapped),
+            Equal => {}
+        }
+        layer.conv_mapped = target;
+    }
+
+    /// Update the physical-page mapping of the ssm cache
+    pub fn update_ssm_mapping(&mut self, layer_idx: usize, pages: &mut MemPages) {
+        let layer = &mut self.layers[layer_idx];
+        let page_size = pages.page_size();
+        let target = self.ssm_size_per_layer.div_ceil(page_size);
+
+        let mem = layer.ssm.get_mut();
+        use std::cmp::Ordering::{Equal, Greater, Less};
+        match layer.ssm_mapped.cmp(&target) {
+            Less => pages.map(mem, layer.ssm_mapped..target),
+            Greater => pages.unmap(mem, target..layer.ssm_mapped),
+            Equal => {}
+        }
+        layer.ssm_mapped = target;
+    }
+
+    /// Get the conv state as a tensor of virtual addresses
+    pub fn conv_tensor(&self, layer_idx: usize) -> Tensor<*const VirByte, 2> {
+        self.layers[layer_idx].conv.as_ref().map(|vir| vir.as_ptr())
+    }
+
+    /// Get the ssm state as a tensor of virtual addresses
+    pub fn ssm_tensor(&self, layer_idx: usize) -> Tensor<*const VirByte, 2> {
+        self.layers[layer_idx].ssm.as_ref().map(|vir| vir.as_ptr())
+    }
+
+    /// Ensure that the physical pages of the given layer are mapped
+    pub fn ensure_mapped(&mut self, layer_idx: usize, pages: &mut MemPages) {
+        self.update_conv_mapping(layer_idx, pages);
+        self.update_ssm_mapping(layer_idx, pages);
+    }
+}
diff --git a/llama.cu/src/exec/mamba_engine.rs b/llama.cu/src/exec/mamba_engine.rs
new file mode 100644
index 00000000..e5270060
--- /dev/null
+++ b/llama.cu/src/exec/mamba_engine.rs
@@ -0,0 +1,793 @@
+use super::{
+    Command, Output, engine_manager::EngineManager, group::ModelGroupMamba, output_head::OutputHead,
+};
+use crate::{
+    CacheParts,
+    batch::{Req, Round},
+    exec::{group::ModelGroupConfig, mamba_cache::MambaCache, sample_manager::SampleManager, upos},
+    handle::Handle,
+    memory::MemPages,
+    op::{FastEmbedding, random_sample::KVPair},
+    utils::{self, Blob, meta},
+};
+use cuda::{ContextResource, CurrentCtx, Device, Event, HostMem};
+use ggus::GGufMetaMapExt;
+use nn::{Distribution, Mamba, Tensor};
+use std::sync::{Mutex, OnceLock};
+use std::{
+    ffi::c_int,
+    iter::zip,
+    marker::PhantomData,
+    num::NonZeroUsize,
+    ops::Deref,
+    sync::{
+        Arc, Barrier, RwLock,
+        mpsc::{Receiver, Sender},
+    },
+};
+use tokeneer::utok;
+
+#[cfg(nccl)]
+use nccl::{Communicator, CommunicatorGroup};
+
+// Global storage for the logprobs of PPL requests
+static LOGPROBS_STORAGE: OnceLock<Mutex<Option<Vec<f32>>>> = OnceLock::new();
+
+/// Check if we should compute
logprobs based on the requests +fn should_compute_logprobs(reqs: &[Req]) -> bool { + // 检查是否有 PPL 请求 + // 简单判断:如果有请求,就计算 logprobs + let should_compute = !reqs.is_empty(); + if should_compute { + println!( + "DEBUG: should_compute_logprobs = true, reqs.len() = {}", + reqs.len() + ); + } + should_compute +} + +/// Compute log_softmax on GPU from logits tensor +/// For PPL calculation, we only need logprobs for specific target tokens +fn compute_log_softmax_on_gpu( + logits: &nn::Tensor, + stream: &cuda::Stream, + target_tokens: Option<&[utok]>, +) -> Vec { + // 获取 logits 的维度信息 + let shape = logits.shape(); + let seq_len = shape[0]; + let vocab_size = shape[1]; + let total_elements = seq_len * vocab_size; + println!("DEBUG: logits shape: [{}, {}]", seq_len, vocab_size); + + // 将 logits 从 GPU 复制到 CPU 进行计算 + let d2h = |tensor: &Tensor| { + let mem_range = tensor.layout().data_range(); + let ptr = tensor.get().as_ptr().cast::(); + let len = *mem_range.end() as usize + tensor.dt().nbytes(); + + // 调试:检查内存复制参数 + println!( + "DEBUG: d2h params - mem_range: {:?}, len: {}, dtype_nbytes: {}", + mem_range, + len, + tensor.dt().nbytes() + ); + println!( + "DEBUG: expected_elements: {}, calculated_bytes: {}", + total_elements, + total_elements * 4 + ); // f32 = 4 bytes + + let slice = unsafe { std::slice::from_raw_parts(ptr, len) }; + let mut host = Blob::new(len); + stream.memcpy_d2h(&mut host, slice); + tensor.as_ref().map(|_| host) + }; + + let host_data = d2h(logits); + let host_logits_raw = host_data.as_deref(); + + // 检查数据类型:如果是 f16,需要转换为 f32 + let dtype_size = logits.dt().nbytes(); + println!("DEBUG: dtype size: {} bytes per element", dtype_size); + + let host_logits: Vec = if dtype_size == 2 { + // f16 格式,需要转换 + println!("DEBUG: Converting f16 to f32"); + let host_logits_f16 = unsafe { + std::slice::from_raw_parts( + host_logits_raw.get().as_ptr().cast::(), + total_elements, + ) + }; + + // 调试:验证 f16 原始数据 + println!("DEBUG: f16 raw data - first 10 values:"); + for i in 0..10.min(total_elements) { + println!( + " f16[{}] = {:?} -> f32 = {:.6}", + i, + host_logits_f16[i], + host_logits_f16[i].to_f32() + ); + } + + // 转换为 f32 + host_logits_f16.iter().map(|&x| x.to_f32()).collect() + } else { + // f32 格式,直接使用 + println!("DEBUG: Using f32 directly"); + let host_logits_f32 = unsafe { + std::slice::from_raw_parts(host_logits_raw.get().as_ptr().cast::(), total_elements) + }; + println!("DEBUG: f32 raw data - first 10 values:"); + for i in 0..10.min(total_elements) { + println!(" f32[{}] = {:?}", i, host_logits_f32[i]); + } + host_logits_f32.to_vec() + }; + + // 验证转换后的数据范围 + let min_val = host_logits.iter().fold(f32::INFINITY, |a, &b| a.min(b)); + let max_val = host_logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)); + let avg_val = host_logits.iter().sum::() / total_elements as f32; + println!( + "DEBUG: converted f32 range - min: {:.6}, max: {:.6}, avg: {:.6}", + min_val, max_val, avg_val + ); + + // 检查是否有异常值 + let nan_count = host_logits.iter().filter(|&&x| x.is_nan()).count(); + let inf_count = host_logits.iter().filter(|&&x| x.is_infinite()).count(); + println!( + "DEBUG: anomaly count - NaN: {}, Inf: {}", + nan_count, inf_count + ); + + // 在 CPU 上计算 log_softmax + if let Some(targets) = target_tokens { + // PPL 模式:只计算目标 token 的 logprobs + let mut logprobs = Vec::with_capacity(seq_len); + println!( + "DEBUG: PPL mode - computing logprobs for {} target tokens", + targets.len() + ); + + for batch_idx in 0..seq_len { + if batch_idx >= targets.len() { + break; // 防止越界 + } + + let start_idx = 
batch_idx * vocab_size; + let batch_logits = &host_logits[start_idx..start_idx + vocab_size]; + let target_token = targets[batch_idx] as usize; + + // 计算这个位置的 log_softmax,但只返回目标 token 的值 + // 1. 找到最大值以提高数值稳定性 + let max_logit = batch_logits + .iter() + .fold(f32::NEG_INFINITY, |a, &b| a.max(b)); + + // 2. 计算 sum(exp(x_j - max)) + let sum_exp: f32 = batch_logits.iter().map(|&x| (x - max_logit).exp()).sum(); + + // 3. 计算 log_sum_exp = max + log(sum_exp) + let log_sum_exp = max_logit + sum_exp.ln(); + + // 4. 只计算目标 token 的 log_softmax + if target_token < vocab_size { + let target_logprob = batch_logits[target_token] - log_sum_exp; + logprobs.push(target_logprob); + + // 调试:打印一些样本值 + if batch_idx < 5 { + println!( + "DEBUG: pos={}, target_token={}, target_logit={:.4}, log_sum_exp={:.4}, logprob={:.4}", + batch_idx, + target_token, + batch_logits[target_token], + log_sum_exp, + target_logprob + ); + } + } else { + println!( + "WARNING: target_token {} >= vocab_size {}", + target_token, vocab_size + ); + logprobs.push(f32::NEG_INFINITY); // 无效 token + } + } + + // 统计 logprobs 的范围 + if !logprobs.is_empty() { + let min_logprob = logprobs.iter().fold(f32::INFINITY, |a, &b| a.min(b)); + let max_logprob = logprobs.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)); + let avg_logprob = logprobs.iter().sum::() / logprobs.len() as f32; + println!( + "DEBUG: Logprobs stats - count: {}, min: {:.4}, max: {:.4}, avg: {:.4}", + logprobs.len(), + min_logprob, + max_logprob, + avg_logprob + ); + } + + println!("DEBUG: Computed {} target token logprobs", logprobs.len()); + logprobs + } else { + // 原始模式:计算所有词汇的 logprobs(兼容性) + let mut logprobs = Vec::with_capacity(total_elements); + + for batch_idx in 0..seq_len { + let start_idx = batch_idx * vocab_size; + let end_idx = start_idx + vocab_size; + let batch_logits = &host_logits[start_idx..end_idx]; + + // 计算 log_softmax + // log_softmax(x_i) = log(exp(x_i) / sum(exp(x_j))) = x_i - log(sum(exp(x_j))) + + // 1. 找到最大值以提高数值稳定性 + let max_logit = batch_logits + .iter() + .fold(f32::NEG_INFINITY, |a, &b| a.max(b)); + + // 2. 计算 sum(exp(x_j - max)) + let sum_exp: f32 = batch_logits.iter().map(|&x| (x - max_logit).exp()).sum(); + + // 3. 计算 log_sum_exp = max + log(sum_exp) + let log_sum_exp = max_logit + sum_exp.ln(); + + // 4. 
计算每个位置的 log_softmax + for &logit in batch_logits { + logprobs.push(logit - log_sum_exp); + } + } + + println!("DEBUG: Computed {} total logprobs", logprobs.len()); + logprobs + } +} + +/// Store logprobs in global storage for API access +fn store_logprobs(logprobs: Vec) { + let storage = LOGPROBS_STORAGE.get_or_init(|| Mutex::new(None)); + if let Ok(mut guard) = storage.lock() { + *guard = Some(logprobs); + } +} + +/// Retrieve and clear stored logprobs +pub fn take_stored_logprobs() -> Option> { + let storage = LOGPROBS_STORAGE.get_or_init(|| Mutex::new(None)); + if let Ok(mut guard) = storage.lock() { + guard.take() + } else { + None + } +} + +// Mamba 专用的推理引擎参数 +const MAMBA_NTOKS: [usize; 5] = [1, 8, 32, 128, 512]; +const MAMBA_CHUNKED_PREFILL_LEN: Option = None; // 关闭 chunked prefill 以支持 PPL +const MAMBA_MAX_TOKS: usize = 512; + +pub(crate) fn mamba_engine( + mamba: Mamba>, + gguf: &impl GGufMetaMapExt, + eos: utok, + workers: &[(c_int, Option>)], + commands: Receiver, + outputs: Sender, + use_cuda_graph: bool, +) { + if let &[(gpu, progress)] = &workers { + return mamba_mono( + mamba, + gguf, + eos, + Device::new(*gpu), + progress.clone(), + commands, + outputs, + use_cuda_graph, + ); + } + + #[cfg(not(nccl))] + unreachable!(); + + #[cfg(nccl)] + { + use std::collections::HashMap; + + let devlist = workers.iter().map(|(gpu, _)| *gpu).collect::>(); + let mut workers = workers.iter().cloned().collect::>(); + + let mut comms = CommunicatorGroup::new(&devlist).into_vec().into_iter(); + let first = comms.next().unwrap(); + + let mut mamba = mamba; + let output_head = mamba.output_head.take().unwrap(); + let worker = MambaWorker { + dev: first.device(), + dist: Distribution { + start: 0, + len: 1, + total: devlist.len(), + }, + progress: workers.remove(&first.device().index()).unwrap(), + config: ModelGroupConfig { + static_model_keys: MAMBA_NTOKS, + dyn_cache_size: 1, + use_cuda_graph, + }, + max_toks: MAMBA_MAX_TOKS, + barrier: Some(Arc::new(Barrier::new(devlist.len()))), + task_box: Default::default(), + chunked_prefill_len: MAMBA_CHUNKED_PREFILL_LEN, + }; + std::thread::scope(|s| { + let _threads = comms + .map(|comm| { + let dev = comm.device(); + let dist = Distribution::new(comm.rank(), 1, devlist.len()); + let worker = MambaWorker { + dev, + dist, + progress: workers.remove(&dev.index()).unwrap(), + ..worker.clone() + }; + let mamba = mamba.clone(); + s.spawn(move || worker.work(mamba, comm)) + }) + .collect::>(); + + worker.lead(mamba, gguf, eos, output_head, commands, outputs, |ctx| { + Handle::with_comm(ctx, first) + }) + }) + } +} + +fn mamba_mono( + mut mamba: Mamba>, + gguf: &impl GGufMetaMapExt, + eos: utok, + dev: Device, + progress: Option>, + commands: Receiver, + outputs: Sender, + use_cuda_graph: bool, +) { + let output_head = mamba.output_head.take().unwrap(); + MambaWorker { + dev, + dist: Distribution { + start: 0, + len: 1, + total: 1, + }, + progress, + config: ModelGroupConfig { + static_model_keys: MAMBA_NTOKS, + dyn_cache_size: 1, + use_cuda_graph, + }, + max_toks: MAMBA_MAX_TOKS, + barrier: None, + task_box: Default::default(), + chunked_prefill_len: MAMBA_CHUNKED_PREFILL_LEN, + } + .lead(mamba, gguf, eos, output_head, commands, outputs, |ctx| { + Handle::new(ctx) + }) +} + +#[derive(Clone)] +struct MambaWorker { + dev: Device, + dist: Distribution, + progress: Option>, + config: ModelGroupConfig, + max_toks: usize, + barrier: Option>, + task_box: MambaTaskBox, + chunked_prefill_len: Option, +} + +type MambaTaskBox = Arc>>; + +#[cfg_attr(not(nccl), 
allow(dead_code))] +struct MambaTask { + key: NonZeroUsize, + reqs: Vec>, +} + +impl> MambaWorker { + fn lead( + self, + mamba: Mamba>, + gguf: &impl GGufMetaMapExt, + eos: utok, + output_head: nn::OutputHead>, + commands: Receiver, + outputs: Sender, + handle: impl FnOnce(&CurrentCtx) -> Handle, + ) { + let Self { + dev, + dist, + progress, + config, + max_toks, + barrier, + task_box, + chunked_prefill_len, + } = self; + + dev.set_mempool_threshold(u64::MAX); + dev.retain_primary().apply(|ctx| { + let mut handle = handle(ctx); + + // 初始化 Mamba 模型组 + let mut models = ModelGroupMamba::new( + mamba, + dist, + progress.clone(), + config, + &mut handle, + barrier.as_deref(), + ); + + let mut manager = EngineManager::new(chunked_prefill_len, max_toks); + let mut output_head = OutputHead::new(output_head, ctx); + let mut sample_manager = SampleManager::new(output_head.nvoc(), eos, ctx); + + // 初始化 Mamba 缓存 + let n_layer: usize = meta![gguf => llm_block_count]; + let d_inner: usize = 5120; // TODO: 从权重/元数据推导 + let d_conv: usize = 4; // kernel size + let d_state: usize = 16; // ssm state size + + let mut pages = MemPages::new(dev); + let mut mamba_cache = MambaCache::new(n_layer, d_inner, d_conv, d_state, &mut pages); + + let max_tok = max_toks; + let mut fast_embd = FastEmbedding::new(max_tok, ctx); + let mut pre_kv_pairs = ctx.malloc::(max_tok); + + let stream = ctx.stream(); + let len = max_toks; + const BUF_LEVEL: usize = 3; + let mut events: [Event; BUF_LEVEL] = std::array::from_fn(|_| stream.record()); + let mut tok_buf = BufN::::new(len, BUF_LEVEL, ctx); + let mut pos_buf = BufN::::new(len, BUF_LEVEL, ctx); + let mut out_idx_buf = BufN::::new(len, BUF_LEVEL, ctx); + let mut fast_embd_buf = BufN::<(utok, utok)>::new(len, BUF_LEVEL, ctx); + + if outputs.send(Output::Ready).is_ok() { + while let Ok(removed) = manager.receive(&commands, &outputs) { + // 处理已移除会话 + sample_manager.remove(removed); + // 组织请求 + let Round { + overflow, + tokens, + reqs, + sample, + output, + fast_map, + finished, + } = manager.prepare(); + // 处理缓存溢出 + sample_manager.remove(overflow.iter().map(|s| s.id)); + if !overflow.is_empty() + && outputs.send(Output::Overflow(overflow.into())).is_err() + { + break; + } + // 如果不需要推理 + if tokens.is_empty() { + assert!( + reqs.is_empty() + && sample.is_empty() + && output.is_empty() + && fast_map.is_empty() + && finished.is_empty() + ); + continue; + } + // 更新 host 多级缓存 + let out_idx = out_idx(&reqs, output.iter().map(|(_, len)| *len)); + events[out_idx_buf.index()].synchronize(); + tok_buf.save(&tokens); + pos_buf.save(&pos(&reqs)); + out_idx_buf.save(&out_idx); + fast_embd_buf.save(&fast_map); + events[out_idx_buf.index()] = stream.record(); + + // 加载输入 - 区分 prefill 和 decode + let (key, tok) = if reqs.len() == 1 && reqs[0].seq == tokens.len() { + // Prefill 阶段 + models.load_inputs_mamba_prefill( + &mut handle, + tokens.len(), + &tokens, + &stream, + ) + } else if tokens.len() == 1 { + // Decode 阶段 (单个 token) + models.load_input_mamba_decode(&mut handle, tokens[0], &stream) + } else { + // 多个 token 但不是完整 prefill,使用 prefill 方法 + models.load_inputs_mamba_prefill( + &mut handle, + tokens.len(), + &tokens, + &stream, + ) + }; + + // 快速启动路径 + fast_embd.launch( + tok, + &pre_kv_pairs, + &fast_embd_buf[..fast_map.len()], + &mut handle, + &stream, + ); + let mut input = stream.malloc::(tok.len() / size_of::()); + stream.memcpy_d2d(&mut input, tok); + + // 通知协处理单元 + #[cfg(nccl)] + if let Some(barrier) = &barrier { + *task_box.write().unwrap() = Some(MambaTask { + key, + reqs: reqs.clone(), 
+ }); + barrier.wait(); + models.share_inputs(key, &mut handle, &stream); + } + + // Mamba 推理 + let x = models.launch_mamba(key, &mut mamba_cache, &mut handle, &stream); + + println!("DEBUG: x shape: {:?}", x.shape()); + + // 对于 PPL 计算,我们需要所有位置的 logits(除了最后一个位置) + let need_logprobs = should_compute_logprobs(&reqs); + let effective_out_idx = if need_logprobs { + // PPL 需要计算位置 0..seq_len-1 的 logits 来预测位置 1..seq_len 的 token + let seq_len = tokens.len(); + if seq_len > 1 { + (0..seq_len - 1).map(|i| i as utok).collect() + } else { + out_idx.clone() + } + } else { + out_idx.clone() + }; + + // 如果没有需要计算的位置,则跳过 + if effective_out_idx.is_empty() { + continue; + } + + println!("DEBUG: original tokens: {:?}", tokens); + println!("DEBUG: token len: {:?}", tokens.len()); + println!("DEBUG: out_idx len: {:?}", out_idx.len()); + println!( + "DEBUG: effective_out_idx len: {:?}", + effective_out_idx.len() + ); + println!("DEBUG: out_idx: {:?}", out_idx); + println!("DEBUG: effective_out_idx: {:?}", effective_out_idx); + + let logits_prefill_last = output_head.launch( + x.clone(), + // &out_idx_buf[..out_idx.len()], + &[0], // 打印logits first token + &mut handle, + &stream, + ); + let logits_prefill_last_vir = + logits_prefill_last.as_ref().map(|mem| mem.as_ptr().cast()); + utils::fmt(&logits_prefill_last_vir, stream.ctx()); // 打印logits_prefill_last + + // 计算输出头 - 这里可以计算 logprobs + let logits = output_head.launch(x, &effective_out_idx, &mut handle, &stream); + + // 计算真实的 logprobs (log_softmax) + if need_logprobs { + println!("DEBUG: Computing logprobs..."); + + // 对于 PPL 计算,我们需要目标 token(即输入序列的下一个位置) + let target_tokens: Vec = if tokens.len() > 1 { + // 目标 token 是位置 1..seq_len 的 token(用于计算位置 0..seq_len-1 的概率) + let targets = tokens[1..].to_vec(); + println!("DEBUG: PPL target tokens: {:?}", targets); + targets + } else { + Vec::new() + }; + + // 调试:检查 GPU logits tensor 信息 + let logits_shape = logits.shape(); + println!("DEBUG: GPU logits tensor shape: {:?}", logits_shape); + println!("DEBUG: GPU logits tensor dtype: {:?}", logits.dt()); + + let logprobs = if !target_tokens.is_empty() { + compute_log_softmax_on_gpu(&logits, &stream, Some(&target_tokens)) + } else { + compute_log_softmax_on_gpu(&logits, &stream, None) + }; + + println!("DEBUG: Computed {} logprobs", logprobs.len()); + // 将 logprobs 存储到全局存储中 + store_logprobs(logprobs); + println!("DEBUG: Stored logprobs"); + } + + // 跳过采样,创建空的 kv_pairs + let kv_pairs = if logits.dt() == nn::digit_layout::types::F32 { + println!("Skipping sampling for f32 logits"); + // 创建大小为 0 的 KVPair DevMem + stream.malloc::(1) + } else { + // 正常采样 + sample_manager.sample(logits, &input, &sample, &stream) + }; + + // 采样 + // let kv_pairs = sample_manager.sample(logits, &input, &sample, &stream); + stream.free(input); + stream.memcpy_d2d(&mut pre_kv_pairs[..kv_pairs.len()], &kv_pairs); + + // 处理推理结束 + sample_manager.remove(finished.iter().map(|s| s.id)); + + // 生成并发送输出 + let output = output + .into_iter() + .filter_map(|(id, len)| if len > 0 { Some((id, len)) } else { None }) + .collect(); + let output = Output::Complete { + output, + kv_pair: kv_pairs.sporulate(), + event: stream.record().sporulate(), + finished: finished.into(), + }; + if outputs.send(output).is_err() { + break; + } + } + } + + // 通知协处理单元退出 + if let Some(barrier) = &barrier { + let _ = task_box.write().unwrap().take(); + barrier.wait(); + } + + // 送回存储的会话信息 + for stub in manager.into_stubs() { + if outputs.send(Output::Removed(stub.session)).is_err() { + break; + } + } + }) + } + + #[cfg(nccl)] + fn 
work(self, mamba: Mamba>, comm: Communicator) { + let Self { + dev, + dist, + progress, + config, + max_toks: _max_toks, + barrier, + task_box, + .. + } = self; + + let barrier = barrier.unwrap(); + dev.set_mempool_threshold(u64::MAX); + dev.retain_primary().apply(|ctx| { + let mut handle = Handle::with_comm(ctx, comm); + let mut models = + ModelGroupMamba::new(mamba, dist, progress, config, &mut handle, Some(&barrier)); + + let stream = ctx.stream(); + loop { + barrier.wait(); + match &*task_box.read().unwrap() { + Some(MambaTask { key, reqs }) => { + models.share_inputs(*key, &mut handle, &stream); + // TODO: 需要实现 Mamba 的 launch 方法 + // models.launch_mamba(*key, reqs, &mut handle, &stream); + } + None => break, + } + } + }) + } +} + +fn pos(reqs: &[Req]) -> Vec { + reqs.iter() + .flat_map(|req| req.pos..req.pos + req.seq) + .map(|x| x as _) + .collect() +} + +fn out_idx(reqs: &[Req], outs: impl IntoIterator) -> Vec { + let mut out_idx = Vec::new(); + + let mut itok = 0; + for (req, out) in zip(reqs, outs) { + for i in req.seq - out..req.seq { + out_idx.push((itok + i) as _) + } + itok += req.seq + } + + out_idx +} + +struct BufN<'ctx, T> { + buf: HostMem<'ctx>, + index: usize, + level: usize, + _phantom: PhantomData, +} + +impl<'ctx, T: Copy> BufN<'ctx, T> { + fn new(len: usize, level: usize, ctx: &'ctx CurrentCtx) -> Self { + Self { + buf: ctx.malloc_host::(len * level), + index: 0, + level, + _phantom: PhantomData, + } + } +} + +impl BufN<'_, T> { + fn save(&mut self, data: &[T]) { + let data = unsafe { std::slice::from_raw_parts(data.as_ptr().cast(), size_of_val(data)) }; + + if self.index + 1 == self.level { + self.index = 0 + } else { + self.index += 1 + } + + let piece = self.buf.len() / self.level; + let (data_, padding) = self.buf[self.index * piece..][..piece].split_at_mut(data.len()); + data_.copy_from_slice(data); + padding.fill(0) + } + + const fn index(&self) -> usize { + self.index + } +} + +impl Deref for BufN<'_, T> { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + let piece = self.buf.len() / self.level; + let (&[], piece, &[]) = + (unsafe { self.buf[self.index * piece..][..piece].align_to::() }) + else { + unreachable!() + }; + piece + } +} diff --git a/llama.cu/src/exec/mod.rs b/llama.cu/src/exec/mod.rs index b5676c1c..6738fbd2 100644 --- a/llama.cu/src/exec/mod.rs +++ b/llama.cu/src/exec/mod.rs @@ -2,6 +2,9 @@ mod engine_manager; mod group; mod kv_cache; +mod mamba; +mod mamba_cache; +mod mamba_engine; mod model; mod output_head; mod sample_manager; @@ -22,6 +25,8 @@ type upos = u32; pub use engine::Progress; pub(crate) use engine::engine; pub(crate) use kv_cache::KVCache; +pub(crate) use mamba_engine::mamba_engine; +pub use mamba_engine::take_stored_logprobs; pub(crate) enum Command { ShutDown, diff --git a/llama.cu/src/exec/model.rs b/llama.cu/src/exec/model.rs index 810f142f..3cf690ea 100644 --- a/llama.cu/src/exec/model.rs +++ b/llama.cu/src/exec/model.rs @@ -1,4 +1,8 @@ -use super::step::Step; +use super::mamba_cache::MambaCache; +use super::step::Step; +use crate::op::Operator; +use crate::op::conv1d::Conv1dWriteStateStep; +use crate::op::scan::SelectiveScanWithWriteback; use crate::{ batch::Req, handle::Handle, @@ -95,6 +99,11 @@ impl ModelExec<'_> { as_mapped(&self.inputs[1]) } + /// 获取第 idx 个全局输入缓冲区用于 Mamba + pub fn input_buf_at(&mut self, idx: usize) -> &mut [DevByte] { + as_mapped(&self.inputs[idx]) + } + pub fn launch( &mut self, handle: &mut Handle, @@ -120,6 +129,126 @@ impl ModelExec<'_> { destruct!([x] = self.outputs.clone()); x } + + // 
capture the conv and ssm exec steps, updating and consuming the mamba cache
+    pub fn launch_with_mamba_cache(
+        &mut self,
+        handle: &mut Handle,
+        cache: &mut MambaCache,
+        stream: &Stream,
+    ) -> Tensor<*const VirByte, 2> {
+        for exec in &self.execs {
+            match exec {
+                Step::Graph(graph, stub) => {
+                    stream.launch_graph(graph);
+                    if !stub.is_empty() {
+                        for t in stub {
+                            utils::fmt(t, stream.ctx())
+                        }
+                        std::process::exit(0);
+                    }
+                }
+                Step::Attention(_box_) => {}
+                Step::Exec(exec) => {
+                    let ty = &exec.node.value.name;
+                    if ty == "mamba-causal-conv1d" || ty == "mamba-selective-scan" {
+                        // parse the layer index from node names of the form "Ω.blk{N}.xxx"
+                        let name = &exec.node.name;
+                        let iblk = {
+                            let prefix = "Ω.blk";
+                            if let Some(pos) = name.find(prefix) {
+                                let mut i = pos + prefix.len();
+                                let bytes = name.as_bytes();
+                                let mut val: usize = 0;
+                                while i < bytes.len() {
+                                    let c = bytes[i] as char;
+                                    if let Some(d) = c.to_digit(10) {
+                                        val = val * 10 + d as usize;
+                                        i += 1;
+                                    } else {
+                                        break;
+                                    }
+                                }
+                                val
+                            } else {
+                                panic!("missing layer index in node name: {}", name)
+                            }
+                        };
+
+                        if ty == "mamba-causal-conv1d" {
+                            let n = exec.inputs[0].shape()[0];
+                            if n == 1 {
+                                // decode: update the state in place
+                                let inputs = [
+                                    exec.inputs[0].clone(),  // x_t [1,d]
+                                    exec.inputs[1].clone(),  // w [d,k]
+                                    exec.inputs[2].clone(),  // b [d]
+                                    cache.conv_tensor(iblk), // state [d,k] (F32)
+                                ];
+                                let outputs = [exec.outputs[0].clone()];
+
+                                crate::op::CausalConv1dStep::launch(
+                                    handle, None, inputs, outputs, stream,
+                                );
+                            } else {
+                                // prefill: write the state first, then compute
+                                let inputs = [
+                                    exec.inputs[0].clone(),  // x [n,d]
+                                    cache.conv_tensor(iblk), // state [d,k]
+                                ];
+                                Conv1dWriteStateStep::launch(
+                                    handle,
+                                    None,
+                                    inputs,
+                                    std::iter::empty(),
+                                    stream,
+                                );
+
+                                handle.launch_nn_exec(exec, stream);
+                            }
+                        } else {
+                            let n = exec.inputs[0].shape()[0];
+                            if n == 1 {
+                                // decode
+                                let inputs = [
+                                    exec.inputs[0].clone(), // u_t [1,d]
+                                    exec.inputs[1].clone(), // delta_t [1,d]
+                                    exec.inputs[2].clone(), // A [d,n_state]
+                                    exec.inputs[3].clone(), // B [1,n_state]
+                                    exec.inputs[4].clone(), // C [1,n_state]
+                                    exec.inputs[5].clone(), // D [d]
+                                    cache.ssm_tensor(iblk), // state [d,n_state] (F32)
+                                ];
+                                let outputs = [exec.outputs[0].clone()];
+                                SelectiveScanWithWriteback::launch(
+                                    handle, None, inputs, outputs, stream,
+                                );
+                            } else {
+                                // prefill
+                                let inputs = [
+                                    exec.inputs[0].clone(), // u [n,d]
+                                    exec.inputs[1].clone(), // delta [n,d]
+                                    exec.inputs[2].clone(), // A [d,n_state]
+                                    exec.inputs[3].clone(), // B [n,n_state]
+                                    exec.inputs[4].clone(), // C [n,n_state]
+                                    exec.inputs[5].clone(), // D [d]
+                                    cache.ssm_tensor(iblk), // state [d,n_state]
+                                ];
+                                let outputs = [exec.outputs[0].clone()];
+                                SelectiveScanWithWriteback::launch(
+                                    handle, None, inputs, outputs, stream,
+                                );
+                            }
+                        }
+                    } else {
+                        handle.launch_nn_exec(exec, stream);
+                    }
+                }
+            }
+        }
+        destruct!([x] = self.outputs.clone());
+        x
+    }
 }
 
 #[allow(clippy::mut_from_ref)]
diff --git a/llama.cu/src/exec/step.rs b/llama.cu/src/exec/step.rs
index a827b0d1..43fe0fd6 100644
--- a/llama.cu/src/exec/step.rs
+++ b/llama.cu/src/exec/step.rs
@@ -118,10 +118,14 @@ impl<'ctx> Handle<'ctx> {
             "rms-norm" => launch!(RmsNorm),
             "layer-norm" => launch!(LayerNorm),
             "linear" => launch!(Linear),
+            "mamba-causal-conv1d" => launch!(Conv1d),
+            "mamba-selective-scan" => launch!(SelectiveScanWithWriteback),
             "rope" => launch!(Rope),
             "mrope" => launch!(MRope),
             "gelu" => launch!(Gelu),
             "swiglu" => launch!(Swiglu),
+            "silu" => launch!(Silu),
+            "element-mul" => launch!(ElementMul),
             #[cfg(nccl)]
             "all-reduce" => launch!(AllReduce),
             "empty" => {}
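For orientation: the decode path of launch_with_mamba_cache hands each layer's conv state to CausalConv1dStep, while prefill first rewrites the window through Conv1dWriteStateStep and then runs the regular convolution. The host-side sketch below summarizes the decode-step semantics; it is illustrative only (the helper name conv_step_ref and the plain-slice layout are assumptions, not part of this patch), and the authoritative version is kernel_step in conv1d.cuh later in this diff.

fn conv_step_ref(
    state: &mut [f32], // per-layer conv state, row-major [d_inner, d_conv], kept in f32
    x_t: &[f32],       // current input step, [d_inner]
    w: &[f32],         // depthwise conv weight, [d_inner, d_conv]
    b: &[f32],         // bias, [d_inner]
    d_conv: usize,
) -> Vec<f32> {
    let d_inner = x_t.len();
    let mut y = vec![0.0f32; d_inner];
    for c in 0..d_inner {
        let row = &mut state[c * d_conv..(c + 1) * d_conv];
        // slide the window one step to the left and append the new input
        row.rotate_left(1);
        row[d_conv - 1] = x_t[c];
        // depthwise convolution over the window, plus bias
        y[c] = row
            .iter()
            .zip(&w[c * d_conv..(c + 1) * d_conv])
            .map(|(s, wk)| s * wk)
            .sum::<f32>()
            + b[c];
    }
    y
}

The prefill path, by contrast, rebuilds the whole window from the last d_conv inputs with left zero-padding, which is what kernel_write_state in conv1d.cuh does before the ordinary convolution runs.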
diff --git a/llama.cu/src/lib.rs b/llama.cu/src/lib.rs
index 2011a175..a32074fd 100644
--- a/llama.cu/src/lib.rs
+++ b/llama.cu/src/lib.rs
@@ -8,8 +8,8 @@ mod op;
 mod utils;
 
 use cuda::{self, Device};
-use exec::{Command, KVCache, Output, Request, engine};
-use ggus::GGufMetaMapExt;
+use exec::{Command, KVCache, Output, Request, engine, mamba_engine};
+use ggus::{GGufMetaMap, GGufMetaMapExt};
 use log::info;
 use memory::MemPages;
 use model::{ChatTemplate, GGufModel, map_files};
@@ -32,9 +32,16 @@ use utils::meta;
 pub use crate::op::random_sample::SampleArgs;
 pub use batch::{Cache, Session, SessionId};
 pub use exec::Progress;
+pub use exec::take_stored_logprobs;
 pub use model::Message;
 pub use tokeneer::{TextBuf, utok};
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ModelType {
+    LLaMA,
+    Mamba,
+}
+
 pub struct Service {
     handle: Option<(Receiver, std::thread::JoinHandle<()>)>,
     ready: bool,
@@ -75,6 +82,19 @@ struct ModelComponents {
 impl Service {
     pub fn new(model: impl AsRef<Path>, gpus: &[c_int], use_cuda_grpah: bool) -> Self {
+        Self::new_with_model_type(model, gpus, use_cuda_grpah, None)
+    }
+
+    pub fn new_mamba(model: impl AsRef<Path>, gpus: &[c_int], use_cuda_grpah: bool) -> Self {
+        Self::new_with_model_type(model, gpus, use_cuda_grpah, Some(ModelType::Mamba))
+    }
+
+    fn new_with_model_type(
+        model: impl AsRef<Path>,
+        gpus: &[c_int],
+        use_cuda_grpah: bool,
+        model_type: Option<ModelType>,
+    ) -> Self {
         info!("start inference @gpu{gpus:?}");
         // create the scheduling channel
         let (outputs, receiver) = mpsc::channel();
@@ -95,23 +115,66 @@ impl Service {
         let once_ = once.clone();
         let handle = std::thread::spawn(move || {
             let mut gguf = GGufModel::read(maps.iter().map(|x| &**x));
-            gguf.insert_rope_sin_cos();
-
             let tokenizer = Bpe::from_gguf(&gguf);
             let chat_template = gguf.chat_template(&tokenizer);
-            let cache_template = gguf.lm_kv_cache();
             let eos = meta![gguf => tokenizer_ggml_eos_token_id];
-            once_.get_or_init(|| ModelComponents {
-                tokenizer,
-                chat_template,
-                cache_template,
-                eos,
-            });
-            drop(once_);
+            // detect the model type
+            let arch = meta![gguf => general_architecture];
+            let model_type = match model_type {
+                Some(model_type) => model_type,
+                None => {
+                    if arch == "mamba" {
+                        ModelType::Mamba
+                    } else {
+                        ModelType::LLaMA
+                    }
+                }
+            };
 
-            let llama = gguf.llama();
-            engine(llama, eos, &workers, commands, outputs, use_cuda_grpah)
+            match model_type {
+                ModelType::LLaMA => {
+                    gguf.insert_rope_sin_cos();
+                    let cache_template = gguf.lm_kv_cache();
+                    once_.get_or_init(|| ModelComponents {
+                        tokenizer,
+                        chat_template,
+                        cache_template,
+                        eos,
+                    });
+                    drop(once_);
+
+                    let llama = gguf.llama();
+                    engine(llama, eos, &workers, commands, outputs, use_cuda_grpah)
+                }
+                ModelType::Mamba => {
+                    // Mamba needs neither RoPE nor a KV cache;
+                    // the KV cache size affects PPL measurement, so set it to max_tokens here
+                    let max_tokens = 1024;
+                    let cache_template = Tensor::from_dim_slice(
+                        nn::digit_layout::types::U64,
+                        [max_tokens, 2, 2, 4, 128],
+                    ); // [nctx, nblk, 2, nkvh, dh]
+                    once_.get_or_init(|| ModelComponents {
+                        tokenizer,
+                        chat_template,
+                        cache_template,
+                        eos,
+                    });
+                    drop(once_);
+
+                    let mamba = gguf.mamba();
+                    mamba_engine(
+                        mamba,
+                        &gguf,
+                        eos,
+                        &workers,
+                        commands,
+                        outputs,
+                        use_cuda_grpah,
+                    )
+                }
+            }
         });
         once.wait();
         Self {
diff --git a/llama.cu/src/model/mamba.rs b/llama.cu/src/model/mamba.rs
new file mode 100644
index 00000000..2718d7a6
--- /dev/null
+++ b/llama.cu/src/model/mamba.rs
@@ -0,0 +1,113 @@
+use super::GGufModel;
+use crate::utils::meta;
+use ggus::GGufMetaMapExt;
+use nn::Tensor;
+
+impl GGufModel<'_> {
+    /// Build the mamba model
+    pub fn mamba(&self) -> nn::Mamba> {
+        let nvoc = meta![self =>
tokenizer_ggml_tokens].len(); + let nblk = meta![self => llm_block_count]; + let d = meta![self => llm_embedding_length]; + let epsilon = meta![self => llm_attention_layer_norm_rms_epsilon; 1e-5]; + let dt_embd = self.tensors["token_embd.weight"].dt(); + let dt_norm = self.tensors["output_norm.weight"].dt(); + let dt_linear = self.tensors["blk.0.ssm_in.weight"].dt(); + + let d_kernel = 4; + let d_inner = 5120; + let d_state = 16; + let dt_rank = 160; // ggus todo: mamba.ssm. + + let get = |name: &str| self.tensors[name].as_deref(); + + ::nn::Mamba { + embedding: ::nn::Embedding { + dt: dt_embd, + d, + wte: ::nn::Table { + row: nvoc, + weight: get("token_embd.weight"), + }, + wpe: None, + }, + blks: (0..nblk) + .map(|iblk| ::nn::MambaBlock { + mamba_norm: nn::Normalization { + d, + epsilon: epsilon as _, + items: ::nn::NormType::RmsNorm { + dt: dt_norm, + scale: get(&format!("blk.{iblk}.attn_norm.weight")), + }, + }, + mamba_mixer: nn::MambaMixer { + d_inner, + in_proj: nn::Linear { + dt: dt_linear, + shape: [d_inner * 2, d], + weight: get(&format!("blk.{iblk}.ssm_in.weight")), + bias: None, + allow_residual: false, + }, + causal_conv1d: nn::CausalConv1d::new( + dt_norm, + get(&format!("blk.{iblk}.ssm_conv1d.weight")), + get(&format!("blk.{iblk}.ssm_conv1d.bias")), + d_kernel, + d_inner, + ), + act: nn::Activation::SiLU, + selective_ssm: nn::SelectiveSSM { + dt: dt_norm, + d_state, + dt_rank, + dt_proj: nn::Linear { + dt: dt_linear, + shape: [d_inner, dt_rank], + weight: get(&format!("blk.{iblk}.ssm_dt.weight")), + bias: Some(dt_linear) + .map(|dt| (dt, get(&format!("blk.{iblk}.ssm_dt.bias")))), + allow_residual: false, + }, + x_proj: nn::Linear { + dt: dt_linear, + shape: [dt_rank + d_state * 2, d_inner], + weight: get(&format!("blk.{iblk}.ssm_x.weight")), + bias: None, + allow_residual: false, + }, + a: get(&format!("blk.{iblk}.ssm_a")), + d: get(&format!("blk.{iblk}.ssm_d")), + }, + out_proj: nn::Linear { + dt: dt_linear, + shape: [d, d_inner], + weight: get(&format!("blk.{iblk}.ssm_out.weight")), + bias: None, + allow_residual: true, + }, + }, + }) + .collect(), + output_head: Some(::nn::OutputHead { + out_norm: ::nn::Normalization { + d, + epsilon: epsilon as _, + items: ::nn::NormType::RmsNorm { + dt: dt_norm, + scale: get("output_norm.weight"), + }, + }, + lm_head: { + let out_linear = if self.tensors.contains_key("output.weight") { + get("output.weight") + } else { + get("token_embd.weight") + }; + ::nn::Linear::new(dt_embd, [nvoc, d], out_linear, None) + }, + }), + } + } +} diff --git a/llama.cu/src/model/mod.rs b/llama.cu/src/model/mod.rs index 67987b42..838e8466 100644 --- a/llama.cu/src/model/mod.rs +++ b/llama.cu/src/model/mod.rs @@ -1,5 +1,6 @@ mod chat_template; mod llama; +mod mamba; mod qw2vl_mmproj; use crate::utils::{Blob, Data}; diff --git a/llama.cu/src/op/conv1d.cuh b/llama.cu/src/op/conv1d.cuh new file mode 100644 index 00000000..4b36509f --- /dev/null +++ b/llama.cu/src/op/conv1d.cuh @@ -0,0 +1,112 @@ +// prefill: 将 full-seq 的最后 K 步写入 conv state,左侧零填充到 K +// x: [n, d], state: [d, K] +template +static __device__ void kernel_write_state( + Tdata const* __restrict__ x, + int const s_n_x, + int const s_d_x, + float* __restrict__ state, + int const s_state_c, + int const s_state_k, + int const n, + int const d, + int const k) +{ + const int c = blockIdx.x * blockDim.x + threadIdx.x; + if (c >= d) return; + + float* row = state + c * s_state_c; + const int pad = (k > n) ? 
(k - n) : 0; + // 前 pad 个位置写 0 + for (int j = 0; j < pad; ++j) { + row[j * s_state_k] = 0.f; + } + // 剩余位置拷贝最后 min(n,k) 步 + const int copy = (k - pad); + for (int j = 0; j < copy; ++j) { + const int t = n - copy + j; // 对应 x 的时间步索引 + const int xi = t * s_n_x + c * s_d_x; + row[(pad + j) * s_state_k] = static_cast(x[xi]); + } +} + +// x[n, d], groups=d, w[d, k], b[d] +template +static __device__ void kernel( + Tdata* __restrict__ y, + int const s_n_y, + int const s_d_y, + Tdata const* __restrict__ x, + int const s_n_x, + int const s_d_x, + Twb const* __restrict__ w, + int const s_d_w, + int const s_k_w, + Twb const* __restrict__ b, + int const s_d_b, + int const kernel_size, + int const padding) { + + const int c = blockIdx.x; + const int pos = blockIdx.y * blockDim.x + threadIdx.x; + + const int seq_len = gridDim.y * blockDim.x; + if (pos >= seq_len) return; + + Tdata sum = Tdata(0); + for (int k = 0; k < kernel_size; ++k) { + const int l_in = pos + k - padding; + if (l_in >= 0 && l_in < seq_len) { + const int x_pos = l_in * s_n_x + c * s_d_x; + const int w_pos = c * s_d_w + k * s_k_w; + sum += Tdata(x[x_pos]) * Tdata(w[w_pos]); + } + } + + sum += Tdata(b[c * s_d_b]); + + const int y_pos = pos * s_n_y + c * s_d_y; + y[y_pos] = sum; +} + +// decode: 输入 x_t[n,d] (n = seq = 1),就地读写 state [d, K],并计算 y[n, d] +template +static __device__ void kernel_step( + Tdata* __restrict__ y, + int const s_n_y, + int const s_d_y, + Tdata const* __restrict__ x_t, + int const s_n_x, + int const s_d_x, + Twb const* __restrict__ w, + int const s_d_w, + int const s_k_w, + Twb const* __restrict__ b, + int const s_d_b, + float* __restrict__ state, + int const s_state_c, + int const s_state_k, + int const kernel_size, + int const k_minus_1, + int const d_channels) +{ + const int c = blockIdx.x * blockDim.x + threadIdx.x; + if (c >= d_channels) return; + + // 更新状态 + float* row = state + c * s_state_c; + for (int j = 0; j < kernel_size - 1; ++j) { + row[j * s_state_k] = row[(j + 1) * s_state_k]; + } + row[(kernel_size - 1) * s_state_k] = static_cast(x_t[c * s_d_x]); + // 计算输出 + Tdata sum = Tdata(0); + for (int i = 0; i < kernel_size; ++i) { + const int wi = c * s_d_w + i * s_k_w; + const int si = c * s_state_c + i * s_state_k; + sum += Tdata(w[wi]) * Tdata(state[si]); + } + sum += Tdata(b[c * s_d_b]); + + y[c * s_d_y] = sum; +} diff --git a/llama.cu/src/op/conv1d.rs b/llama.cu/src/op/conv1d.rs new file mode 100644 index 00000000..acdc3655 --- /dev/null +++ b/llama.cu/src/op/conv1d.rs @@ -0,0 +1,285 @@ +use super::{Handle, ModuleKey, Operator, cuda_type, gcd}; +use crate::utils::{destruct, dims, offset_ptr, strides}; +use cuda::{Stream, VirByte, params}; +use nn::{Tensor, digit_layout::DigitLayout}; +use std::ffi::{c_int, c_uint}; + +pub struct Conv1d; + +impl Operator for Conv1d { + fn launch<'a, const N: usize>( + handle: &mut Handle, + _arg: Option, + inputs: impl IntoIterator>, + outputs: impl IntoIterator>, + stream: &Stream, + ) { + destruct!([x, w, b] = inputs); + destruct!([y] = outputs); + + dims!([n_x, d_x] = x); + dims!([d_w, k] = w); + dims!([d_b] = b); + dims!([n_y, d_y] = y); + + assert_eq!(n_x, n_y); + assert_eq!(d_x, d_w); + assert_eq!(d_x, d_b); + assert_eq!(d_x, d_y); + let kernel_size = k as c_int; + let padding = kernel_size - 1; + + let dt = x.dt(); + assert_eq!(y.dt(), dt); + let dt_wb = w.dt(); + assert_eq!(b.dt(), dt_wb); + + strides!([s_n_x, s_d_x] = x); + strides!([s_n_y, s_d_y] = y); + strides!([s_d_w, s_k_w] = w); + strides!([s_d_b] = b); + + let unit = dt.nbytes() as isize; + let unit_w = 
w.dt().nbytes() as isize; + assert_eq!(s_d_x, unit); + assert_eq!(s_d_y, unit); + assert_eq!(s_k_w, unit_w); + assert_eq!(s_d_b, unit_w); + + let max_threads_block = handle.ctx.dev().block_limit().max_threads; + + let key = [ + ModuleKey::Text("conv1d"), + ModuleKey::Type(dt), + ModuleKey::Type(dt_wb), + ] + .into_iter(); + let module = handle.compile(key.collect(), || code(dt, dt_wb)); + let kernel = module.get_kernel(c"conv1d"); + + let params = params![ + offset_ptr(&y), + (s_n_y / unit) as c_int, + (s_d_y / unit) as c_int, + offset_ptr(&x), + (s_n_x / unit) as c_int, + (s_d_x / unit) as c_int, + offset_ptr(&w), + (s_d_w / unit_w) as c_int, + (s_k_w / unit_w) as c_int, + offset_ptr(&b), + (s_d_b / unit_w) as c_int, + kernel_size, + padding + ]; + + let block = gcd(max_threads_block, n_y); + + stream.launch( + &kernel, + ( + ((n_y / block) as c_uint, (d_b as c_uint)), + block as c_uint, + 0, + ), + ¶ms.to_ptrs(), + ); + } +} + +fn code(dt: DigitLayout, dt_wb: DigitLayout) -> String { + const CODE: &str = include_str!("conv1d.cuh"); + let dt = cuda_type(dt); + let dt_wb = cuda_type(dt_wb); + format!( + r#"{CODE} + +extern "C" __global__ void conv1d( + {dt} *__restrict__ y, + int const s_n_y, + int const s_d_y, + {dt} const *__restrict__ x, + int const s_n_x, + int const s_d_x, + {dt_wb} const *__restrict__ w, + int const s_d_w, + int const s_k_w, + {dt_wb} const *__restrict__ b, + int const s_d_b, + int const kernel_size, + int const padding) {{ + kernel(y, s_n_y, s_d_y, + x, s_n_x, s_d_x, + w, s_d_w, s_k_w, + b, s_d_b, + kernel_size, padding); +}} + +extern "C" __global__ void conv1d_step( + {dt} *__restrict__ y, + int const s_n_y, + int const s_d_y, + {dt} const *__restrict__ x, + int const s_n_x, + int const s_d_x, + {dt_wb} const *__restrict__ w, + int const s_d_w, + int const s_k_w, + {dt_wb} const *__restrict__ b, + int const s_d_b, + float* __restrict__ state, + int const s_state_c, + int const s_state_k, + int const kernel_size, + int const k_minus_1, + int const d_channels) {{ + kernel_step<{dt}, {dt_wb}>( + y, s_n_y, s_d_y, + x, s_n_x, s_d_x, + w, s_d_w, s_k_w, + b, s_d_b, + state, s_state_c, s_state_k, + kernel_size, k_minus_1, d_channels); +}} + +extern "C" __global__ void conv1d_write_state( + {dt} const *__restrict__ x, + int const s_n_x, + int const s_d_x, + float* __restrict__ state, + int const s_state_c, + int const s_state_k, + int const n, + int const d, + int const k) {{ + kernel_write_state<{dt}>(x, s_n_x, s_d_x, state, s_state_c, s_state_k, n, d, k); +}} +"#, + ) +} + +pub struct CausalConv1dStep; + +impl Operator for CausalConv1dStep { + fn launch<'a, const N: usize>( + handle: &mut Handle, + _arg: Option, + inputs: impl IntoIterator>, + outputs: impl IntoIterator>, + stream: &Stream, + ) { + destruct!([x_t, w, b, state] = inputs); + destruct!([y_t] = outputs); + + dims!([n_x, d_x] = x_t); + dims!([d_w, k] = w); + dims!([d_b] = b); + dims!([n_y, d_y] = y_t); + dims!([d_state, k_state] = state); + + assert_eq!(n_x, 1); + assert_eq!(n_y, 1); + assert_eq!(d_x, d_w); + assert_eq!(d_x, d_b); + assert_eq!(d_x, d_y); + assert_eq!(d_state, d_x); + assert_eq!(k_state, k); + + let dt = x_t.dt(); + assert_eq!(y_t.dt(), dt); + let dt_wb = w.dt(); + assert_eq!(b.dt(), dt_wb); + assert_eq!(state.dt(), nn::digit_layout::types::F32); + + strides!([s_n_x, s_d_x] = x_t); + strides!([s_n_y, s_d_y] = y_t); + strides!([s_d_w, s_k_w] = w); + strides!([s_d_b] = b); + strides!([s_state_c, s_state_k] = state); + + let max_threads_block = handle.ctx.dev().block_limit().max_threads; + let 
key = [ + ModuleKey::Text("conv1d"), + ModuleKey::Type(dt), + ModuleKey::Type(dt_wb), + ] + .into_iter(); + let module = handle.compile(key.collect(), || code(dt, dt_wb)); + let kernel = module.get_kernel(c"conv1d_step"); + + let params = params![ + offset_ptr(&y_t), + (s_n_y / dt.nbytes() as isize) as c_int, + (s_d_y / dt.nbytes() as isize) as c_int, + offset_ptr(&x_t), + (s_n_x / dt.nbytes() as isize) as c_int, + (s_d_x / dt.nbytes() as isize) as c_int, + offset_ptr(&w), + (s_d_w / dt_wb.nbytes() as isize) as c_int, + (s_k_w / dt_wb.nbytes() as isize) as c_int, + offset_ptr(&b), + (s_d_b / dt_wb.nbytes() as isize) as c_int, + offset_ptr(&state), + (s_state_c / 4) as c_int, + (s_state_k / 4) as c_int, + k as c_int, + (k - 1) as c_int, + d_y as c_int + ]; + + let block = gcd(max_threads_block, d_y); + let grid_x: usize = d_y.div_ceil(block); + // (grid.y, grid.x) + stream.launch( + &kernel, + ((1 as c_uint, grid_x as c_uint), block as c_uint, 0), + ¶ms.to_ptrs(), + ); + } +} + +pub struct Conv1dWriteStateStep; + +impl Operator for Conv1dWriteStateStep { + fn launch<'a, const N: usize>( + handle: &mut Handle, + _arg: Option, + inputs: impl IntoIterator>, + _outputs: impl IntoIterator>, + stream: &Stream, + ) { + destruct!([x, state] = inputs); + let dt = x.dt(); + dims!([n, d] = x); + dims!([d_state, k] = state); + assert_eq!(d_state, d); + assert_eq!(state.dt(), nn::digit_layout::types::F32); + strides!([s_n_x, s_d_x] = x); + strides!([s_state_c, s_state_k] = state); + + let key = [ModuleKey::Text("conv1d"), ModuleKey::Type(dt)].into_iter(); + let max_threads_block = handle.ctx.dev().block_limit().max_threads; + let module = handle.compile(key.collect(), || code(dt, dt)); + let kernel = module.get_kernel(c"conv1d_write_state"); + + let params = params![ + offset_ptr(&x), + (s_n_x / dt.nbytes() as isize) as c_int, + (s_d_x / dt.nbytes() as isize) as c_int, + offset_ptr(&state), + (s_state_c / 4) as c_int, + (s_state_k / 4) as c_int, + n as c_int, + d as c_int, + k as c_int + ]; + + let block = gcd(max_threads_block, d); + let grid_x: usize = d.div_ceil(block); + stream.launch( + &kernel, + ((1 as c_uint, grid_x as c_uint), block as c_uint, 0), + ¶ms.to_ptrs(), + ); + } +} diff --git a/llama.cu/src/op/element_mul.cuh b/llama.cu/src/op/element_mul.cuh new file mode 100644 index 00000000..7fd55135 --- /dev/null +++ b/llama.cu/src/op/element_mul.cuh @@ -0,0 +1,19 @@ +template +__device__ inline void kernel( + T* __restrict__ y, + int const stride_y_n, + T const* __restrict__ a, + int const stride_a_n, + T const* __restrict__ b, + int const stride_b_n, + int const n, + int const d) { + int n_idx = blockIdx.x; + int d_blk = blockIdx.y; + int d_idx = threadIdx.x + d_blk * blockDim.x; + if (d_idx >= d) return; + T const* a_row = a + n_idx * stride_a_n; + T const* b_row = b + n_idx * stride_b_n; + T* y_row = y + n_idx * stride_y_n; + y_row[d_idx] = a_row[d_idx] * b_row[d_idx]; +} diff --git a/llama.cu/src/op/element_mul.rs b/llama.cu/src/op/element_mul.rs new file mode 100644 index 00000000..e1c0dee4 --- /dev/null +++ b/llama.cu/src/op/element_mul.rs @@ -0,0 +1,85 @@ +use super::{Handle, ModuleKey, Operator, cuda_type, gcd}; +use crate::utils::{destruct, dims, offset_ptr, strides}; +use cuda::{Stream, VirByte, params}; +use nn::{Tensor, digit_layout::DigitLayout}; +use std::ffi::{c_int, c_uint}; + +pub struct ElementMul; + +impl Operator for ElementMul { + fn launch<'a, const N: usize>( + handle: &mut Handle, + _arg: Option, + inputs: impl IntoIterator>, + outputs: impl IntoIterator>, + stream: 
+        stream: &Stream,
+    ) {
+        destruct!([a, b] = inputs);
+        destruct!([y] = outputs);
+
+        dims!([n, d] = a);
+        dims!([n2, d2] = b);
+        dims!([n3, d3] = y);
+        assert_eq!(n, n2);
+        assert_eq!(d, d2);
+        assert_eq!(n, n3);
+        assert_eq!(d, d3);
+
+        let dt = a.dt();
+        assert_eq!(b.dt(), dt);
+        assert_eq!(y.dt(), dt);
+
+        strides!([s_n_a, s_d_a] = a);
+        strides!([s_n_b, s_d_b] = b);
+        strides!([s_n_y, s_d_y] = y);
+        let unit = dt.nbytes() as isize;
+        assert_eq!(s_d_a, unit);
+        assert_eq!(s_d_b, unit);
+        assert_eq!(s_d_y, unit);
+
+        let max_threads_block = handle.ctx.dev().block_limit().max_threads;
+        let key = [ModuleKey::Text("element-mul"), ModuleKey::Type(dt)].into_iter();
+        let module = handle.compile(key.collect(), || code(dt));
+        let kernel = module.get_kernel(c"element_mul");
+
+        let params = params![
+            offset_ptr(&y),
+            (s_n_y / unit) as c_int,
+            offset_ptr(&a),
+            (s_n_a / unit) as c_int,
+            offset_ptr(&b),
+            (s_n_b / unit) as c_int,
+            n as c_int,
+            d as c_int
+        ];
+
+        let block = gcd(max_threads_block, d);
+        stream.launch(
+            &kernel,
+            (((d / block) as c_uint, n as c_uint), block as c_uint, 0),
+            &params.to_ptrs(),
+        );
+    }
+}
+
+fn code(dt: DigitLayout) -> String {
+    const CODE: &str = include_str!("element_mul.cuh");
+    let dt = cuda_type(dt);
+    format!(
+        r#"{CODE}
+
+extern "C" __global__ void element_mul(
+    {dt} *__restrict__ y,
+    int const stride_y_n,
+    {dt} const *__restrict__ a,
+    int const stride_a_n,
+    {dt} const *__restrict__ b,
+    int const stride_b_n,
+    int const n,
+    int const d
+){{
+    kernel(y, stride_y_n, a, stride_a_n, b, stride_b_n, n, d);
+}}
+"#
+    )
+}
diff --git a/llama.cu/src/op/mod.rs b/llama.cu/src/op/mod.rs
index 28394832..efd87ae7 100644
--- a/llama.cu/src/op/mod.rs
+++ b/llama.cu/src/op/mod.rs
@@ -1,6 +1,8 @@
 mod add;
 #[cfg(nccl)]
 mod all_reduce;
+pub mod conv1d;
+mod element_mul;
 mod embedding;
 mod fast_embedding;
 mod gelu;
@@ -9,6 +11,8 @@ mod linear;
 mod mrope;
 mod rms_norm;
 mod rope;
+pub mod scan;
+mod silu;
 mod swiglu;
 
 use crate::handle::Handle;
@@ -22,6 +26,8 @@ pub mod random_sample;
 #[cfg(nccl)]
 pub use all_reduce::AllReduce;
+pub use conv1d::{CausalConv1dStep, Conv1d};
+pub use element_mul::ElementMul;
 pub use embedding::Embedding;
 pub use fast_embedding::FastEmbedding;
 pub use gelu::Gelu;
@@ -30,6 +36,8 @@ pub use linear::Linear;
 pub use mrope::MRope;
 pub use rms_norm::RmsNorm;
 pub use rope::Rope;
+pub use scan::SelectiveScanWithWriteback;
+pub use silu::Silu;
 pub use swiglu::Swiglu;
 
 pub trait Operator {
diff --git a/llama.cu/src/op/scan.cuh b/llama.cu/src/op/scan.cuh
new file mode 100644
index 00000000..4496579a
--- /dev/null
+++ b/llama.cu/src/op/scan.cuh
@@ -0,0 +1,94 @@
+template <class Tdata, class Tad>
+static __device__ void kernel_with_writeback(
+    Tdata *__restrict__ out,
+    int const stride_out_seq,
+    int const stride_out_d_in,
+    Tdata const *__restrict__ u,
+    int const stride_u_seq,
+    int const stride_u_d_in,
+    Tdata const *__restrict__ delta,
+    int const stride_delta_seq,
+    int const stride_delta_d_in,
+    Tad const *__restrict__ A,
+    int const stride_A_d_in,
+    int const stride_A_n,
+    Tdata const *__restrict__ B,
+    int const stride_B_seq,
+    int const stride_B_n,
+    Tdata const *__restrict__ C,
+    int const stride_C_seq,
+    int const stride_C_n,
+    Tad const *__restrict__ D,
+    int const stride_D_d_in,
+    int const seq_len,
+    int const d_in,
+    int const n,
+    float *__restrict__ state_out,
+    int const stride_state_d,
+    int const stride_state_n)
+{
+    assert(gridDim.x == 5120 && gridDim.x == d_in);
+    const int i_d_in = blockIdx.x + blockIdx.y * gridDim.x;
+    assert(blockIdx.x >= 0 && blockIdx.x < 5120);
+    assert(blockIdx.y == 0);
+    assert(threadIdx.x == 0);
+    if (i_d_in >= d_in) return;
+
+    extern __shared__ unsigned char smem_raw[];
+    float *x = reinterpret_cast<float *>(smem_raw);
+
+    // initialize the state: zero for prefill (seq_len > 1), restore from the cache for single-step decode
+    if (seq_len > 1) {
+        for (int i = 0; i < n; ++i)
+            x[i] = 0.f;
+    } else if (seq_len == 1) {
+        for (int i = 0; i < n; ++i)
+            x[i] = state_out[i_d_in * stride_state_d + i * stride_state_n];
+    } else {
+        assert(false);
+    }
+
+    for (int t = 0; t < seq_len; ++t) {
+        const auto i_u = t * stride_u_seq + i_d_in * stride_u_d_in;
+        const auto i_delta = t * stride_delta_seq + i_d_in * stride_delta_d_in;
+        const auto i_B = t * stride_B_seq;
+        const auto i_C = t * stride_C_seq;
+
+        const Tdata u_ = u[i_u];
+        const float dx = static_cast<float>(delta[i_delta]);
+        const float delta_f = (dx <= 20.f) ? log1pf(expf(dx)) : dx;
+
+        float y = 0.f;
+        for (int i = 0; i < n; ++i)
+        {
+            const auto i_A = i_d_in * stride_A_d_in + i * stride_A_n;
+            const auto i_B_ = i_B + i * stride_B_n;
+            const float a = static_cast<float>(A[i_A]);
+            const float b = static_cast<float>(B[i_B_]);
+            float xi = x[i];
+            const float u_f = static_cast<float>(u_);
+            const float deltaA = expf(delta_f * a);
+            const float deltaB_u = delta_f * b * u_f;
+            // state equation
+            xi = deltaA * xi + deltaB_u;
+            x[i] = xi;
+
+            const auto i_C_ = i_C + i * stride_C_n;
+            // output equation
+            y += static_cast<float>(C[i_C_]) * x[i];
+        }
+
+        const auto i_D = i_d_in * stride_D_d_in;
+        y += static_cast<float>(D[i_D]) * static_cast<float>(u_);
+
+        const auto i_out = t * stride_out_seq + i_d_in * stride_out_d_in;
+        out[i_out] = static_cast<Tdata>(y);
+    }
+
+    // write the state back to the cache
+    if (state_out != nullptr) {
+        for (int i = 0; i < n; ++i) {
+            state_out[i_d_in * stride_state_d + i * stride_state_n] = x[i];
+        }
+    }
+}
diff --git a/llama.cu/src/op/scan.rs b/llama.cu/src/op/scan.rs
new file mode 100644
index 00000000..332a823c
--- /dev/null
+++ b/llama.cu/src/op/scan.rs
@@ -0,0 +1,162 @@
+use super::{Handle, ModuleKey, Operator, cuda_type};
+use crate::utils::{destruct, dims, offset_ptr, strides};
+use cuda::{Stream, VirByte, params};
+use nn::{Tensor, digit_layout::DigitLayout};
+use std::ffi::{c_int, c_uint};
+
+pub struct SelectiveScanWithWriteback;
+#[allow(non_snake_case)]
+impl Operator for SelectiveScanWithWriteback {
+    fn launch<'a, const N: usize>(
+        handle: &mut Handle,
+        _arg: Option,
+        inputs: impl IntoIterator>, // [u,delta,A,B,C,D,state]
+        outputs: impl IntoIterator>, // [out]
+        stream: &Stream,
+    ) {
+        destruct!([u, delta, A, B, C, D, state] = inputs);
+        destruct!([out] = outputs);
+
+        dims!([n, d] = u);
+        dims!([n_delta, d_delta] = delta);
+        dims!([d_A, n_state] = A);
+        dims!([n_B, n_state_B] = B);
+        dims!([n_C, n_state_C] = C);
+        dims!([d_D] = D);
+        dims!([n_out, d_out] = out);
+        dims!([d_state, n_state2] = state);
+        assert_eq!(n, n_delta);
+        assert_eq!(d, d_delta);
+        assert_eq!(n, n_B);
+        assert_eq!(n, n_C);
+        assert_eq!(d, d_A);
+        assert_eq!(d, d_D);
+        assert_eq!(n, n_out);
+        assert_eq!(d, d_out);
+        assert_eq!(d_state, d);
+        assert_eq!(n_state, n_state_B);
+        assert_eq!(n_state, n_state_C);
+        assert_eq!(n_state, n_state2);
+        assert_eq!(state.dt(), nn::digit_layout::types::F32);
+        assert_eq!(state.dt(), A.dt());
+        assert_eq!(d_state, 5120);
+        assert_eq!(n_state2, 16);
+        assert_eq!(d, 5120);
+
+        let dt = u.dt();
+        let dt_AD = A.dt();
+        strides!([s_n_u, s_d_u] = u);
+        strides!([s_n_delta, s_d_delta] = delta);
+        strides!([s_d_A, s_s_A] = A);
+        strides!([s_n_B, s_s_B] = B);
+        strides!([s_n_C, s_s_C] = C);
+        strides!([s_d_D] = D);
+        strides!([s_n_out, s_d_out] = out);
+        strides!([s_state_d, s_state_n] = state);
+        assert_eq!(s_state_d / 4, n_state as isize);
+        assert_eq!(s_state_n / 4, 1);
+
+        let unit = dt.nbytes() as isize;
+        let unit_AD = dt_AD.nbytes() as isize;
+        let key = [
+            ModuleKey::Text("scan"),
+            ModuleKey::Type(dt),
+            ModuleKey::Type(dt_AD),
+            ModuleKey::Size(n_state),
+        ]
+        .into_iter();
+        let module = handle.compile(key.collect(), || code(dt, dt_AD, n_state));
+        let kernel = module.get_kernel(c"scan_with_writeback");
+
+        let params = params![
+            offset_ptr(&out),
+            (s_n_out / unit) as c_int,
+            (s_d_out / unit) as c_int,
+            offset_ptr(&u),
+            (s_n_u / unit) as c_int,
+            (s_d_u / unit) as c_int,
+            offset_ptr(&delta),
+            (s_n_delta / unit) as c_int,
+            (s_d_delta / unit) as c_int,
+            offset_ptr(&A),
+            (s_d_A / unit_AD) as c_int,
+            (s_s_A / unit_AD) as c_int,
+            offset_ptr(&B),
+            (s_n_B / unit) as c_int,
+            (s_s_B / unit) as c_int,
+            offset_ptr(&C),
+            (s_n_C / unit) as c_int,
+            (s_s_C / unit) as c_int,
+            offset_ptr(&D),
+            (s_d_D / unit_AD) as c_int,
+            n as c_int,
+            d as c_int,
+            n_state as c_int,
+            offset_ptr(&state),
+            (s_state_d / 4) as c_int,
+            (s_state_n / 4) as c_int
+        ];
+
+        let block = 1u32;
+        let shared_mem_bytes = n_state * 4usize;
+        // (grid.y, grid.x)
+        stream.launch(
+            &kernel,
+            (
+                (1 as c_uint, d as c_uint),
+                block as c_uint,
+                shared_mem_bytes,
+            ),
+            &params.to_ptrs(),
+        );
+    }
+}
+
+#[allow(non_snake_case)]
+pub(crate) fn code(dt: DigitLayout, dt_AD: DigitLayout, _state: usize) -> String {
+    const CODE: &str = include_str!("scan.cuh");
+    let dt = cuda_type(dt);
+    let dt_AD = cuda_type(dt_AD);
+    format!(
+        r#"{CODE}
+
+extern "C" __global__ void scan_with_writeback(
+    {dt} *__restrict__ out,
+    int const s_n_out,
+    int const s_d_out,
+    {dt} const *__restrict__ u,
+    int const s_n_u,
+    int const s_d_u,
+    {dt} const *__restrict__ delta,
+    int const s_n_delta,
+    int const s_d_delta,
+    {dt_AD} const *__restrict__ A,
+    int const s_d_A,
+    int const s_s_A,
+    {dt} const *__restrict__ B,
+    int const s_n_B,
+    int const s_s_B,
+    {dt} const *__restrict__ C,
+    int const s_n_C,
+    int const s_s_C,
+    {dt_AD} const *__restrict__ D,
+    int const s_d_D,
+    int const n,
+    int const d,
+    int const state,
+    float *__restrict__ state_out,
+    int const s_state_d,
+    int const s_state_n) {{
+    kernel_with_writeback(out, s_n_out, s_d_out,
+                          u, s_n_u, s_d_u,
+                          delta, s_n_delta, s_d_delta,
+                          A, s_d_A, s_s_A,
+                          B, s_n_B, s_s_B,
+                          C, s_n_C, s_s_C,
+                          D, s_d_D,
+                          n, d, state,
+                          state_out, s_state_d, s_state_n);
+}}
+"#
+    )
+}
diff --git a/llama.cu/src/op/silu.cuh b/llama.cu/src/op/silu.cuh
new file mode 100644
index 00000000..66aff6e8
--- /dev/null
+++ b/llama.cu/src/op/silu.cuh
@@ -0,0 +1,16 @@
+static __forceinline__ __device__ float sigmoid(float x) {
+    return fdividef(1, 1 + expf(-x));
+}
+
+template <class Tdata>
+static __device__ void kernel(
+    Tdata *__restrict__ out,
+    int const stride_out,
+    Tdata const *__restrict__ up_,
+    int const stride_up) {
+    auto n = blockIdx.x * blockDim.x + threadIdx.x,
+         i = blockIdx.y * stride_out + n,
+         k = blockIdx.y * stride_up + n;
+    float up = up_[k];
+    out[i] = Tdata(up * sigmoid(up));
+}
diff --git a/llama.cu/src/op/silu.rs b/llama.cu/src/op/silu.rs
new file mode 100644
index 00000000..960fe7fb
--- /dev/null
+++ b/llama.cu/src/op/silu.rs
@@ -0,0 +1,85 @@
+use super::{Handle, ModuleKey, Operator, cuda_type, gcd};
+use crate::utils::{destruct, dims, offset_ptr, strides};
+use cuda::{Stream, VirByte, params};
+use nn::{Tensor, digit_layout::DigitLayout};
+use std::ffi::{c_int, c_uint};
+
+pub struct Silu;
+
+impl Operator for Silu {
+    fn launch<'a, const N: usize>(
+        handle: &mut Handle,
+        arg: Option,
+        inputs: impl IntoIterator>,
+        outputs: impl IntoIterator>,
+        stream: &Stream,
+    ) {
+        assert!(arg.is_none());
+
+        destruct!([up] = inputs);
+        destruct!([out] = outputs);
+
+        // check dimensions
+        dims!([n, d] = up);
+        dims!([n2, d2] = out);
+
+        assert_eq!(n, n2);
+        assert_eq!(d, d2);
+
+        // check data types
+        let dt = up.dt();
+        assert_eq!(out.dt(), dt);
+
+        // fetch the strides
+        strides!([s_n_up, s_d_up] = up);
+        strides!([s_n_out, s_d_out] = out);
+
+        // make sure the strides are as expected
+        let unit = dt.nbytes() as isize;
+        assert_eq!(s_d_up, unit);
+        assert_eq!(s_d_out, unit);
+
+        // query the maximum number of threads per block
+        let max_threads_block = handle.ctx.dev().block_limit().max_threads;
+
+        // compile the kernel
+        let key = [ModuleKey::Text("silu"), ModuleKey::Type(dt)].into_iter();
+        let module = handle.compile(key.collect(), || code(dt));
+        let kernel = module.get_kernel(c"silu");
+
+        // prepare the kernel parameters
+        let params = params![
+            offset_ptr(&out),
+            (s_n_out / unit) as c_int,
+            offset_ptr(&up),
+            (s_n_up / unit) as c_int
+        ];
+
+        // compute the launch configuration
+        let block = gcd(max_threads_block, d);
+
+        stream.launch(
+            &kernel,
+            ((n as c_uint, (d / block) as c_uint), block as c_uint, 0),
+            &params.to_ptrs(),
+        );
+    }
+}
+
+fn code(dt: DigitLayout) -> String {
+    const CODE: &str = include_str!("silu.cuh");
+    let dt = cuda_type(dt);
+
+    format!(
+        r#"{CODE}
+
+extern "C" __global__ void silu(
+    {dt} *__restrict__ out,
+    int const stride_out,
+    {dt} const *__restrict__ up,
+    int const stride_up
+){{
+    kernel(out, stride_out, up, stride_up);
+}}"#
+    )
+}
diff --git a/test_ppl.py b/test_ppl.py
new file mode 100644
index 00000000..00d210df
--- /dev/null
+++ b/test_ppl.py
@@ -0,0 +1,209 @@
+#!/usr/bin/env python3
+import argparse
+import math
+import os
+import sys
+import time
+from typing import List, Tuple
+
+import requests
+
+try:
+    from datasets import load_dataset  # type: ignore
+    import torch
+    from transformers import MambaForCausalLM, AutoTokenizer
+    from tqdm import tqdm
+except Exception as e:  # pragma: no cover
+    print("[ERROR] Please install the dependencies first: pip install -U datasets requests torch transformers tqdm", file=sys.stderr)
+    raise
+
+
+def fetch_logprobs(
+    api_base: str,
+    model: str,
+    text: str,
+    timeout: float,
+) -> Tuple[List[float], int]:
+    """Call /completions and return the prompt's token logprobs and the token count.
+
+    The server must implement echo=true and support logprobs>=1 with max_tokens=0.
+    Raises RuntimeError if the server does not return logprobs.
+    """
+    url = api_base.rstrip("/") + "/completions"
+    payload = {
+        "model": model,
+        "prompt": text,
+        "max_tokens": 0,
+        "echo": True,
+        "logprobs": 1,
+        "stream": False,
+    }
+    resp = requests.post(url, json=payload, timeout=timeout)
+    if resp.status_code != 200:
+        raise RuntimeError(f"HTTP {resp.status_code}: {resp.text[:512]}")
+    data = resp.json()
+    try:
+        choice = data["choices"][0]
+        lp = choice["logprobs"]
+    except Exception:
+        raise RuntimeError(
+            "The service did not return a logprobs field. Make sure /completions supports echo and logprobs."
+        )
+    token_logprobs = lp.get("token_logprobs")
+    if token_logprobs is None:
+        raise RuntimeError(
+            "logprobs.token_logprobs is empty. Implement token-level log probabilities on the server side."
+        )
+    # Filter out None entries (e.g. special tokens) and aggregate only valid logprobs
+    valid_lps = [x for x in token_logprobs if x is not None]
+    return valid_lps, len(valid_lps)
+
+
+def compute_ppl_on_dataset(
+    api_base: str,
+    model: str,
+    dataset_name: str,
+    config: str,
+    split: str,
+    max_samples: int,
+    timeout: float,
+) -> float:
+    ds = load_dataset(dataset_name, config, split=split)
+
+    total_nll = 0.0
+    total_tokens = 0
+    processed = 0
+
+    for item in ds:
+        # Field names differ between configs; wikitext-*-raw-v1 usually uses 'text'
+        text = (item.get("text") or "").strip()
+        if not text:
+            continue
+        try:
+            token_logprobs, num_toks = fetch_logprobs(
+                api_base, model, text, timeout)
+        except Exception as e:
+            # Re-raise the first error directly so the user can fix the server side
+            raise
+
+        if num_toks == 0:
+            continue
+        total_tokens += num_toks
+        total_nll += -sum(token_logprobs)  # NLL = -log p(token)
+        processed += 1
+
+        if 0 < max_samples <= processed:
+            break
+
+    if total_tokens == 0:
+        raise RuntimeError("No token logprobs were obtained; cannot compute PPL.")
+
+    avg_nll = total_nll / total_tokens
+    ppl = math.exp(avg_nll)
+    return ppl
+
+
+def compute_ppl_pytorch(model, tokenizer, texts, max_length=1024, max_samples=None):
+    """Compute PPL with the PyTorch Mamba reference implementation."""
+    model.eval()
+    total_loss = 0.0
+    total_tokens = 0
+    processed = 0
+
+    with torch.no_grad():
+        for text in tqdm(texts, desc="Computing PyTorch PPL"):
+            # Tokenize
+            inputs = tokenizer(text, return_tensors="pt",
+                               truncation=True, max_length=max_length)
+            input_ids = inputs["input_ids"].to(model.device)
+
+            if input_ids.size(1) < 2:  # at least 2 tokens are needed to compute a loss
+                continue
+
+            # Forward pass
+            outputs = model(input_ids, labels=input_ids)
+            loss = outputs.loss
+
+            # Accumulate the loss and the token count
+            total_loss += loss.item() * (input_ids.size(1) - 1)  # minus 1 because the last token has no target
+            total_tokens += input_ids.size(1) - 1
+            processed += 1
+
+            # Stop once the maximum number of samples is reached
+            if max_samples and max_samples > 0 and processed >= max_samples:
+                break
+
+    # Compute PPL
+    avg_loss = total_loss / total_tokens
+    ppl = math.exp(avg_loss)
+    return ppl
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Evaluate PPL via OpenAI /completions (echo+logprobs)")
+    parser.add_argument("--api-base", type=str,
+                        default="http://127.0.0.1:8080", help="service address")
+    parser.add_argument("--model", type=str, required=True,
+                        help="model name (the id returned by /models)")
+    parser.add_argument("--dataset", type=str,
+                        default="wikitext", help="HF datasets name")
+    parser.add_argument("--config", type=str,
+                        default="wikitext-2-raw-v1", help="HF datasets config name")
+    parser.add_argument("--split", type=str, default="test",
+                        help="dataset split, e.g. test/validation")
+    parser.add_argument("--max-samples", type=int,
+                        default=500, help="maximum number of samples to evaluate (effective when > 0)")
+    parser.add_argument("--timeout", type=float,
+                        default=120.0, help="HTTP request timeout in seconds")
+
+    args = parser.parse_args()
+
+    # Compute the original version's PPL (via the HTTP service)
+    print("=== Computing original-version PPL ===")
+    t0 = time.time()
+    ppl_original = compute_ppl_on_dataset(
+        api_base=args.api_base,
+        model=args.model,
+        dataset_name=args.dataset,
+        config=args.config,
+        split=args.split,
+        max_samples=args.max_samples,
+        timeout=args.timeout,
+    )
+    dt_original = time.time() - t0
+    print(f"Original-version PPL = {ppl_original:.4f} (time: {dt_original:.2f}s)")
+
+    # Compute the PyTorch version's PPL
+    print("\n=== Computing PyTorch-version PPL ===")
+    print("Loading PyTorch model and tokenizer...")
+    tokenizer = AutoTokenizer.from_pretrained(
+        "/home/cearx/qy/model/mamba-2.8b-hf")
+    model = MambaForCausalLM.from_pretrained(
+        "/home/cearx/qy/model/mamba-2.8b-hf", device_map="cuda")
+
+    # Load the same dataset
+    print("Loading dataset...")
+    dataset = load_dataset(args.dataset, args.config, split=args.split)
+    texts = [item["text"] for item in dataset if item["text"].strip()]
+
+    t1 = time.time()
+    ppl_pytorch = compute_ppl_pytorch(
+        model, tokenizer, texts, max_samples=args.max_samples)
+    dt_pytorch = time.time() - t1
+    print(f"PyTorch-version PPL = {ppl_pytorch:.4f} (time: {dt_pytorch:.2f}s)")
+
+    # Compare the two results
+    print("\n=== PPL comparison ===")
+    print(f"Original-version PPL: {ppl_original:.4f}")
+    print(f"PyTorch PPL: {ppl_pytorch:.4f}")
+    diff_abs = abs(ppl_original - ppl_pytorch)
+    diff_rel = diff_abs / ppl_pytorch * 100
+    print(f"Absolute difference: {diff_abs:.4f}")
+    print(f"Relative difference: {diff_rel:.2f}%")
{dt_original:.2f}s") + print(f"PyTorch时间: {dt_pytorch:.2f}s") + + +if __name__ == "__main__": + main() diff --git a/xtask/src/main.rs b/xtask/src/main.rs index ffb624b3..a743f3d5 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -29,6 +29,7 @@ fn main() { Generate(args) => args.generate(), Chat(args) => args.chat(), Service(args) => args.service(), + MambaService(args) => args.mamba_service(), Bench(args) => args.bench(), } } @@ -49,6 +50,8 @@ enum Commands { Chat(chat::ChatArgs), /// web service Service(service::ServiceArgs), + /// mamba web service (with logprobs support) + MambaService(service::MambaServiceArgs), /// batched benchmark Bench(bench::BenchArgs), } diff --git a/xtask/src/service/error.rs b/xtask/src/service/error.rs index a65e0121..fd2517cd 100644 --- a/xtask/src/service/error.rs +++ b/xtask/src/service/error.rs @@ -8,6 +8,7 @@ pub(crate) enum Error { NotFound(NotFoundError), MsgNotSupported(MsgNotSupportedError), ModelNotFound(String), + InternalError(String), } #[derive(Serialize, Debug)] @@ -42,6 +43,7 @@ impl Error { Self::NotFound(..) => StatusCode::NOT_FOUND, Self::MsgNotSupported(..) => StatusCode::BAD_REQUEST, Self::ModelNotFound(..) => StatusCode::NOT_FOUND, + Self::InternalError(..) => StatusCode::INTERNAL_SERVER_ERROR, } } @@ -52,6 +54,7 @@ impl Error { Self::NotFound(e) => serde_json::to_string(&e).unwrap(), Self::MsgNotSupported(e) => serde_json::to_string(&e).unwrap(), Self::ModelNotFound(model) => format!("Model not found: {model}"), + Self::InternalError(msg) => format!("Internal server error: {msg}"), } } } @@ -63,6 +66,7 @@ impl fmt::Display for Error { Error::NotFound(e) => write!(f, "Not Found: {} {}", e.method, e.uri), Error::MsgNotSupported(e) => write!(f, "Message type not supported: {:?}", e.message), Error::ModelNotFound(model) => write!(f, "Model not found: {model}"), + Error::InternalError(msg) => write!(f, "Internal error: {msg}"), } } } diff --git a/xtask/src/service/mod.rs b/xtask/src/service/mod.rs index cd58bbf2..a80b6338 100644 --- a/xtask/src/service/mod.rs +++ b/xtask/src/service/mod.rs @@ -8,7 +8,10 @@ mod response; use crate::{ parse_gpus, service::{ - openai::{chat_completion_response, chat_completion_response_stream, completion_response}, + openai::{ + chat_completion_response, chat_completion_response_stream, completion_response, + completion_response_with_logprobs, + }, response::text_stream, }, }; @@ -28,6 +31,7 @@ use openai::create_models; use openai_struct::{CreateChatCompletionRequest, CreateCompletionRequest}; use response::error; use response::json; +use serde_json::Value; use std::{ collections::HashMap, sync::atomic::{AtomicUsize, Ordering::SeqCst}, @@ -68,6 +72,29 @@ pub struct ServiceArgs { think: bool, } +#[derive(Args)] +pub struct MambaServiceArgs { + model: String, + + #[clap(short, long)] + port: u16, + #[clap(long)] + no_cuda_graph: bool, + + #[clap(long)] + name: Option, + #[clap(long)] + gpus: Option, + #[clap(long)] + max_tokens: Option, + #[clap(long)] + temperature: Option, + #[clap(long)] + top_p: Option, + #[clap(long)] + repetition_penalty: Option, +} + #[derive(serde::Deserialize, Debug)] pub struct ModelConfig { pub path: String, @@ -191,7 +218,8 @@ impl HyperService> for App { let models = self.0.clone(); Box::pin(async move { let whole_body = req.collect().await?.to_bytes(); - let req: CreateCompletionRequest = match serde_json::from_slice(&whole_body) { + let mut req: CreateCompletionRequest = match serde_json::from_slice(&whole_body) + { Ok(req) => req, Err(e) => return 
+                        Err(e) => return Ok(error(Error::WrongJson(e))),
+                    };
@@ -204,14 +232,25 @@ impl HyperService> for App {
                         }
                     };
                     let stream = req.stream.unwrap_or(true);
+
+                    // PPL evaluation support
+                    let echo = req.echo.unwrap_or(false);
+                    let logprobs = req.logprobs;
+                    let max_tokens = req.max_tokens.unwrap_or(16);
+                    let prompt_text = match &req.prompt {
+                        Value::String(s) => s.clone(),
+                        Value::Array(arr) => arr
+                            .iter()
+                            .filter_map(|v| v.as_str())
+                            .collect::>()
+                            .join("\n"),
+                        _ => String::new(),
+                    };
+
                     let model = match models.get(&model_name) {
                         Some(model) => model,
                         None => return Ok(error(Error::ModelNotFound(model_name))),
                     };
-                    let mut receiver = match model.complete(req) {
-                        Ok(receiver) => receiver,
-                        Err(e) => return Ok(error(e)),
-                    };
 
                     static ID: AtomicUsize = AtomicUsize::new(0);
                     let id = ID.fetch_add(1, SeqCst);
@@ -220,6 +259,19 @@ impl HyperService> for App {
                         .unwrap()
                         .as_secs() as i32;
+                    // Special case: PPL evaluation (max_tokens=0, echo=true, logprobs present)
+                    // Strategy: route PPL requests through the normal inference path, but force max_tokens=1 to trigger logprobs computation
+                    let is_ppl_request = max_tokens == 0 && echo && logprobs.is_some();
+                    if is_ppl_request {
+                        // Adjust the request: set max_tokens=1 so inference runs; later we return only the prompt's logprobs
+                        req.max_tokens = Some(1);
+                    }
+
+                    let mut receiver = match model.complete(req) {
+                        Ok(receiver) => receiver,
+                        Err(e) => return Ok(error(e)),
+                    };
+
                     if stream {
                         return Ok(text_stream(UnboundedReceiverStream::new(receiver).map(
                             move |output| {
@@ -259,7 +311,55 @@ impl HyperService> for App {
                         }
                     }
 
-                    let response = completion_response(id, created, model_name, content_, reason_);
+                    // If this is a PPL request, return a logprobs response
+                    if is_ppl_request {
+                        // Try to fetch the logprobs from the global store
+                        if let Some(stored_logprobs) = llama_cu::take_stored_logprobs() {
+                            // Tokenize to recover per-token information
+                            let tokens = model.tokenize(&prompt_text);
+                            let mut token_strings = Vec::new();
+                            let mut text_offsets = Vec::new();
+                            let mut current_offset = 0i32;
+
+                            for &token in &tokens {
+                                let token_text = model.decode(&[token]);
+                                token_strings.push(token_text.clone());
+                                text_offsets.push(current_offset);
+                                current_offset += token_text.len() as i32;
+                            }
+
+                            // Only return the prompt's logprobs (excluding generated content)
+                            let prompt_logprobs = if stored_logprobs.len() >= tokens.len() {
+                                stored_logprobs[..tokens.len()].to_vec()
+                            } else {
+                                stored_logprobs
+                            };
+
+                            let response = completion_response_with_logprobs(
+                                id,
+                                created,
+                                model_name,
+                                prompt_text.clone(),
+                                reason_,
+                                Some(prompt_logprobs),
+                                Some(token_strings),
+                                Some(text_offsets),
+                            );
+                            return Ok(json(response));
+                        } else {
+                            return Ok(error(Error::InternalError(
+                                "No logprobs were computed during inference".to_string(),
+                            )));
+                        }
+                    }
+
+                    // Normal response
+                    let text_body = if echo {
+                        format!("{prompt_text}{content_}")
+                    } else {
+                        content_
+                    };
+                    let response = completion_response(id, created, model_name, text_body, reason_);
                     Ok(json(response))
                 })
             }
@@ -363,6 +463,61 @@ impl HyperService> for App {
     }
 }
 
+impl MambaServiceArgs {
+    pub fn mamba_service(self) {
+        let Self {
+            model,
+            port,
+            no_cuda_graph,
+            name,
+            gpus,
+            max_tokens,
+            temperature,
+            top_p,
+            repetition_penalty,
+        } = self;
+
+        let model_name = name.unwrap_or_else(|| {
+            std::path::Path::new(&model)
+                .file_stem()
+                .unwrap()
+                .to_str()
+                .unwrap()
+                .to_string()
+        });
+
+        info!("Starting Mamba service: {}", model_name);
+        info!("Model path: {}", model);
+        info!("Port: {}", port);
+
+        // Build the Mamba model configuration
+        let model_config = ModelConfig {
+            path: model,
+            gpus: Some(parse_gpus(gpus.as_deref())),
+            max_tokens,
+            temperature,
+            top_p,
+            repetition_penalty,
+            think: Some(false), // Mamba does not support think mode yet
+            blacklist: None,
+        };
+
+        info!("{}: {:?}", model_name, model_config);
+
+        // Create the model and service - Mamba does not support CUDA Graph, so it is force-disabled
+        let (model, service) = Model::new_mamba(model_config, false);
+        let model = Arc::new(model);
+        let handles = vec![(model.clone(), service)];
+        let models = [(model_name, model)].into();
+
+        // Start the service
+        tokio::runtime::Runtime::new()
+            .unwrap()
+            .block_on(start_infer_service(models, handles, port))
+            .unwrap()
+    }
+}
+
 #[cfg(test)]
 mod blacklist_integration_test;
 #[cfg(test)]
diff --git a/xtask/src/service/model.rs b/xtask/src/service/model.rs
index f8242637..d7de91bc 100644
--- a/xtask/src/service/model.rs
+++ b/xtask/src/service/model.rs
@@ -118,6 +118,63 @@ impl Model {
         (model, service)
     }
 
+    pub fn new_mamba(config: ModelConfig, use_cuda_graph: bool) -> (Self, Service) {
+        let ModelConfig {
+            path,
+            gpus,
+            max_tokens,
+            temperature,
+            top_p,
+            repetition_penalty,
+            think,
+            blacklist,
+        } = config;
+
+        // Mamba does not support CUDA Graph; force-disable it
+        let mut service = Service::new_mamba(path, &gpus.unwrap_or(Box::new([0])), false);
+        progress_bar(&mut service);
+
+        let think = if think.unwrap_or(false) {
+            let &[think] = &*service.terminal().encode("<think>") else {
+                unreachable!()
+            };
+            let &[_think] = &*service.terminal().encode("</think>") else {
+                unreachable!()
+            };
+            [think, _think]
+        } else {
+            [utok::MAX; 2]
+        };
+
+        let blacklist = blacklist
+            .unwrap_or_default()
+            .into_iter()
+            .map(|s| s.to_lowercase())
+            .collect::>();
+
+        let model = Model {
+            max_tokens: max_tokens.unwrap_or(2 << 10),
+            sampling: SampleArgs::new(
+                temperature.unwrap_or(0.),
+                top_p.unwrap_or(1.),
+                usize::MAX,
+                repetition_penalty.unwrap_or(1.),
+            )
+            .unwrap(),
+            think,
+            terminal: service.terminal().clone(),
+            sessions: Default::default(),
+            cache_manager: Default::default(),
+            blacklist_checker: if blacklist.is_empty() {
+                None
+            } else {
+                Some(BlacklistChecker::new(blacklist))
+            },
+        };
+
+        (model, service)
+    }
+
     pub fn serve(&self, service: &mut Service) {
         let [think, _think] = self.think;
         loop {
@@ -375,6 +432,18 @@ impl Model {
 
         let (sender, receiver) = mpsc::unbounded_channel();
 
+        // When max_tokens == 0, skip decoding entirely and return Finish right away, avoiding the underlying "Cannot decode 0 step" assertion failure
+        if max_tokens == 0 {
+            // Finish immediately: no session is registered and no decoding is triggered
+            sender
+                .send(Output::Finish {
+                    reason: FinishReason::Stop,
+                    num_tokens: [tokens.len(), tokens.len()],
+                })
+                .ok();
+            return Ok(receiver);
+        }
+
         let (id, tokens) = self.cache_manager.lock().unwrap().send(
             &self.terminal,
             tokens,
@@ -409,4 +478,72 @@ impl Model {
             .map(|checker| checker.contains_word(text))
             .unwrap_or(false)
     }
+
+    /// Tokenize text using the model's tokenizer
+    pub fn tokenize(&self, text: &str) -> Vec {
+        self.terminal.tokenize(text)
+    }
+
+    /// Decode tokens to text using the model's tokenizer
+    pub fn decode(&self, tokens: &[utok]) -> String {
+        let mut buf = TextBuf::new();
+        self.terminal.decode(tokens, &mut buf)
+    }
+
+    /// Compute logprobs for the given text (for PPL evaluation)
+    /// Returns (token_logprobs, token_strings, text_offsets)
+    pub async fn compute_logprobs(
+        &self,
+        text: &str,
+    ) -> Result<(Vec, Vec, Vec), Box> {
+        // Tokenize
+        let tokens = self.tokenize(text);
+        if tokens.is_empty() {
+            return Ok((Vec::new(), Vec::new(), Vec::new()));
+        }
+
+        // Prepare the return values
+        let mut token_logprobs = Vec::new();
+        let mut token_strings = Vec::new();
+        let mut text_offsets = Vec::new();
+        let mut current_offset = 0i32;
+
+        // Collect per-token information for each position
+        for &token in &tokens {
+            let token_text = self.decode(&[token]);
+            token_strings.push(token_text.clone());
+            text_offsets.push(current_offset);
+            current_offset += token_text.len() as i32;
+        }
+
+        // Compute the logprobs for all positions in one pass
+        // Here we issue a dedicated inference request to obtain the logprobs
+        token_logprobs = self.compute_sequence_logprobs(&tokens).await?;
+
+        Ok((token_logprobs, token_strings, text_offsets))
+    }
+
+    /// Compute logprobs for an entire token sequence using real model inference
+    async fn compute_sequence_logprobs(
+        &self,
+        tokens: &[utok],
+    ) -> Result, Box> {
+        // Check whether stored logprobs are available (computed by the Mamba engine)
+        if let Some(stored_logprobs) = llama_cu::take_stored_logprobs() {
+            // Return the stored, real logprobs
+            let num_tokens = tokens.len();
+            if stored_logprobs.len() >= num_tokens {
+                Ok(stored_logprobs[..num_tokens].to_vec())
+            } else {
+                Ok(stored_logprobs)
+            }
+        } else {
+            // If no logprobs are stored, the engine did not compute them;
+            // in that case return an informative error
+            Err(
+                "No logprobs available. The Mamba engine should compute logprobs during inference."
+                    .into(),
+            )
+        }
+    }
 }
diff --git a/xtask/src/service/openai.rs b/xtask/src/service/openai.rs
index fa7bcd60..d80694e3 100644
--- a/xtask/src/service/openai.rs
+++ b/xtask/src/service/openai.rs
@@ -105,6 +105,19 @@ pub(crate) fn completion_response(
     id: usize,
     created: i32,
     model: String,
     text: String,
     finish_reason: Option,
+) -> CreateCompletionResponse {
+    completion_response_with_logprobs(id, created, model, text, finish_reason, None, None, None)
+}
+
+pub(crate) fn completion_response_with_logprobs(
+    id: usize,
+    created: i32,
+    model: String,
+    text: String,
+    finish_reason: Option,
+    token_logprobs: Option>,
+    tokens: Option>,
+    text_offset: Option>,
 ) -> CreateCompletionResponse {
     let finish_reason = match finish_reason {
         Some(FinishReason::Stop) => "stop",
@@ -120,9 +133,9 @@ pub(crate) fn completion_response(
         finish_reason,
         index: 0,
         logprobs: CreateCompletionResponseLogprobs {
-            text_offset: None,
-            token_logprobs: None,
-            tokens: None,
+            text_offset,
+            token_logprobs,
+            tokens,
             top_logprobs: None,
         },
     }];
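
For reference, the /completions change above only threads the logprobs fields through to the response; the perplexity math itself lives on the client side in test_ppl.py. The snippet below is a minimal sketch of that client protocol (echo=true, logprobs=1, max_tokens=0, then PPL = exp of the mean negative logprob); the api_base URL and model name are placeholder assumptions, while the request and response fields mirror the payload used in test_ppl.py.

import math
import requests

def prompt_ppl(text: str, api_base: str = "http://127.0.0.1:8080", model: str = "mamba-2.8b-hf") -> float:
    # Same request shape as test_ppl.py: echo the prompt, ask for token logprobs, generate nothing.
    payload = {"model": model, "prompt": text, "max_tokens": 0, "echo": True, "logprobs": 1, "stream": False}
    resp = requests.post(f"{api_base.rstrip('/')}/completions", json=payload, timeout=120)
    resp.raise_for_status()
    lps = resp.json()["choices"][0]["logprobs"]["token_logprobs"]
    valid = [x for x in lps if x is not None]  # skip tokens without a logprob
    return math.exp(-sum(valid) / len(valid))  # PPL = exp(mean NLL)

if __name__ == "__main__":
    print(prompt_ppl("The quick brown fox jumps over the lazy dog."))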