|
| 1 | +use crate::exec::group::{ModelGroupConfig, ModelGroupMamba}; |
| 2 | +use crate::exec::mamba_cache::MambaCache; |
| 3 | +use crate::exec::output_head::OutputHead; |
| 4 | +use crate::exec::sample_manager::SampleManager; |
| 5 | +use crate::memory::MemPages; |
| 6 | +use crate::op::random_sample::{KVPair, SampleArgs}; |
| 7 | +use crate::utils::{self, meta}; |
| 8 | +use crate::{handle::Handle, model::map_files}; |
| 9 | +use cuda::Device; |
| 10 | +use ggus::GGufMetaMapExt; |
| 11 | +use nn::Distribution; |
| 12 | +use std::env; |
| 13 | +use tokeneer::Bpe; |
| 14 | + |
| 15 | +#[allow(dead_code)] |
| 16 | +pub fn mamba_infer( |
| 17 | + model_path: std::path::PathBuf, |
| 18 | + text: &str, |
| 19 | + use_cuda_graph: bool, |
| 20 | +) -> (Vec<u8>, [usize; 2]) { |
| 21 | + use crate::model::GGufModel; |
| 22 | + // 初始化 CUDA |
| 23 | + assert!(cuda::init().is_ok()); |
| 24 | + |
| 25 | + // 加载模型 |
| 26 | + let maps = map_files(model_path); |
| 27 | + let gguf = GGufModel::read(maps.iter().map(|x| &**x)); |
| 28 | + let tokenizer = Bpe::from_gguf(&gguf); |
| 29 | + let mut tokens = tokenizer.encode(text); |
| 30 | + |
| 31 | + let n_tok = tokens.len(); |
| 32 | + |
| 33 | + // 取出输出头用于 logits 计算 |
| 34 | + let mut mamba = gguf.mamba(); |
| 35 | + let output_head_nn = mamba |
| 36 | + .output_head |
| 37 | + .take() |
| 38 | + .expect("mamba model missing output_head"); |
| 39 | + |
| 40 | + let n_layer: usize = meta![gguf => llm_block_count]; |
| 41 | + let d_inner: usize = 5120; // TODO: ggus |
| 42 | + let d_conv: usize = 4; // kernel size |
| 43 | + let d_state: usize = 16; // ssm state size |
| 44 | + |
| 45 | + // 单卡 |
| 46 | + let device = Device::new(0); |
| 47 | + device.retain_primary().apply(|ctx| { |
| 48 | + let mut handle = Handle::new(ctx); |
| 49 | + let dist = Distribution { |
| 50 | + start: 0, |
| 51 | + len: 1, |
| 52 | + total: 1, |
| 53 | + }; |
| 54 | + let mut models = ModelGroupMamba::new( |
| 55 | + mamba, |
| 56 | + dist, |
| 57 | + None, |
| 58 | + ModelGroupConfig { |
| 59 | + static_model_keys: [n_tok], |
| 60 | + dyn_cache_size: 1, |
| 61 | + use_cuda_graph, |
| 62 | + }, |
| 63 | + &mut handle, |
| 64 | + None, |
| 65 | + ); |
| 66 | + |
| 67 | + // 组件:输出头与采样器 |
| 68 | + let stream = ctx.stream(); |
| 69 | + let mut output_head = OutputHead::new(output_head_nn, ctx); |
| 70 | + |
| 71 | + // 读取词表大小构建采样器与 eos |
| 72 | + let eos: tokeneer::utok = meta![gguf => tokenizer_ggml_eos_token_id]; |
| 73 | + let nvoc = output_head.nvoc(); |
| 74 | + let mut sample_manager = SampleManager::new(nvoc, eos, ctx); |
| 75 | + |
| 76 | + // 初始化 MambaCache |
| 77 | + let mut pages = MemPages::new(device); |
| 78 | + let mut mcache = MambaCache::new(n_layer, d_inner, d_conv, d_state, &mut pages); |
| 79 | + |
| 80 | + // Prefill |
| 81 | + let (key, _tok_buf) = |
| 82 | + models.load_inputs_mamba_prefill(&mut handle, tokens.len(), &tokens, &stream); |
| 83 | + |
| 84 | + let mut x = models.launch_mamba(key, &mut mcache, &mut handle, &stream); |
| 85 | + |
| 86 | + let last_idx: [tokeneer::utok; 1] = [(tokens.len() - 1) as tokeneer::utok]; |
| 87 | + let logits_prefill_last = output_head.launch(x.clone(), &last_idx, &mut handle, &stream); |
| 88 | + |
| 89 | + let logits_prefill_last_vir = logits_prefill_last.as_ref().map(|mem| mem.as_ptr().cast()); |
| 90 | + utils::fmt(&logits_prefill_last_vir, stream.ctx()); |
| 91 | + // check prefill logits |
| 92 | + |
| 93 | + let mut next_id: tokeneer::utok; |
| 94 | + { |
| 95 | + let mut input = stream.malloc::<tokeneer::utok>(tokens.len()); |
| 96 | + stream.memcpy_h2d(&mut input, &tokens); |
| 97 | + let cfg0 = vec![( |
| 98 | + crate::batch::SessionId(0), |
| 99 | + crate::batch::SampleInfo { |
| 100 | + args: SampleArgs::new(0.8, 0.95, 50, 1.2).unwrap(), |
| 101 | + input_idx: tokens.len(), |
| 102 | + decode_len: tokens.len(), |
| 103 | + }, |
| 104 | + )]; |
| 105 | + let kv_pairs0 = sample_manager.sample(logits_prefill_last, &input, &cfg0, &stream); |
| 106 | + stream.free(input); |
| 107 | + let mut host_kv0 = vec![KVPair::ZERO; 1]; |
| 108 | + stream.memcpy_d2h(&mut host_kv0, &kv_pairs0).free(kv_pairs0); |
| 109 | + next_id = host_kv0[0].idx as tokeneer::utok; |
| 110 | + } |
| 111 | + let mut generated: Vec<tokeneer::utok> = Vec::new(); |
| 112 | + if next_id != eos { |
| 113 | + tokens.push(next_id); |
| 114 | + generated.push(next_id); |
| 115 | + let (key, _tok_buf) = models.load_input_mamba_decode(&mut handle, next_id, &stream); |
| 116 | + x = models.launch_mamba(key, &mut mcache, &mut handle, &stream); |
| 117 | + } |
| 118 | + |
| 119 | + let max_decode_steps: usize = env::var("MAMBA_STEPS") |
| 120 | + .ok() |
| 121 | + .and_then(|s| s.parse().ok()) |
| 122 | + .unwrap_or(100); |
| 123 | + for _step in 1..max_decode_steps { |
| 124 | + let out_idx: [tokeneer::utok; 1] = [0]; |
| 125 | + |
| 126 | + let logits = output_head.launch(x.clone(), &out_idx, &mut handle, &stream); |
| 127 | + |
| 128 | + let mut input = stream.malloc::<tokeneer::utok>(tokens.len()); |
| 129 | + stream.memcpy_h2d(&mut input, &tokens); |
| 130 | + let cfg = vec![( |
| 131 | + crate::batch::SessionId(0), |
| 132 | + crate::batch::SampleInfo { |
| 133 | + args: SampleArgs::new(0.8, 0.95, 50, 1.2).unwrap(), |
| 134 | + input_idx: tokens.len(), |
| 135 | + decode_len: tokens.len(), |
| 136 | + }, |
| 137 | + )]; |
| 138 | + let kv_pairs = sample_manager.sample(logits, &input, &cfg, &stream); |
| 139 | + stream.free(input); |
| 140 | + let mut host_kv = vec![KVPair::ZERO; 1]; |
| 141 | + stream.memcpy_d2h(&mut host_kv, &kv_pairs).free(kv_pairs); |
| 142 | + next_id = host_kv[0].idx as tokeneer::utok; |
| 143 | + |
| 144 | + if next_id == eos { |
| 145 | + break; |
| 146 | + } |
| 147 | + |
| 148 | + tokens.push(next_id); |
| 149 | + generated.push(next_id); |
| 150 | + let (key, _tok_buf) = models.load_input_mamba_decode(&mut handle, next_id, &stream); |
| 151 | + x = models.launch_mamba(key, &mut mcache, &mut handle, &stream); |
| 152 | + } |
| 153 | + |
| 154 | + println!("tokens = {:?}", tokens); |
| 155 | + let mut text_buf = tokeneer::TextBuf::new(); |
| 156 | + let s = tokenizer.decode(&generated, &mut text_buf); |
| 157 | + let buf = s.into_bytes(); |
| 158 | + |
| 159 | + let shape = <[usize; 2]>::try_from(x.shape().to_vec()).unwrap(); |
| 160 | + (buf, shape) |
| 161 | + }) |
| 162 | +} |
| 163 | + |
// Disabled test: requires a local GGUF model file and a CUDA device, so it
// cannot run in CI. Re-enable manually (or mark `#[ignore]`) when the model
// is available at the path below.
// #[cfg(test)]
// mod tests {
//     use super::*;
//     use std::path::PathBuf;
//
//     #[test]
//     fn test_mamba_infer_decode() {
//         let model = PathBuf::from("../model/Mamba_adf32-2.8B-hf-v1.0-F16.gguf");
//         let prompt = "Once upon a time,";
//         let (bytes, _shape) = mamba_infer(model, prompt, false);
//         let text = String::from_utf8_lossy(&bytes);
//         println!("prompt = {}", prompt);
//         println!("mamba infer text = {}", text);
//     }
// }