Skip to content

Commit 374c27f

Browse files
committed
fix: 实现 cpu 推理
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 55c8997 commit 374c27f

File tree

8 files changed

+72
-22
lines changed

8 files changed

+72
-22
lines changed

Cargo.lock

Lines changed: 1 addition & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,6 @@ test-utils.path = "test-utils"
1717
ggus = { git = "https://github.com/YdrMaster/gguf", rev = "e64d758" }
1818
ggml-quants = { git = "https://github.com/YdrMaster/gguf", rev = "e64d758" }
1919
ndarray-layout = { git = "https://github.com/YdrMaster/ndarray-layout", rev = "5c6b969" }
20-
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "9b5c6b9", default-features = false }
20+
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "656e7f7", default-features = false }
2121

2222
memmap2 = "0.9"

gguf/src/lib.rs

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,15 @@
11
mod chat_template;
22
mod tokenizer;
33

4+
use ggml_quants::digit_layout::DigitLayout;
45
use ggus::{
5-
ggml_quants::digit_layout::DigitLayout, GGuf, GGufError, GGufFileName, GGufMetaDataValueType,
6-
GGufMetaKV, GGufMetaMap, GGufReadError, GENERAL_ALIGNMENT,
6+
GGuf, GGufError, GGufFileName, GGufMetaDataValueType, GGufMetaKV, GGufMetaMap, GGufReadError,
7+
GENERAL_ALIGNMENT,
78
};
89
use memmap2::Mmap;
910
use std::{collections::HashMap, fmt::Debug, fs::File, path::Path};
1011

12+
pub use ggus::{ggml_quants, GGufMetaError, GGufMetaMapExt};
1113
pub use tokenizer::Tokenizer;
1214

1315
/// 从指定文件的路径出发,映射所有分片文件。

models/llama/common-cpu/src/lib.rs

Lines changed: 57 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,9 @@ use memmap2::Mmap;
77
use operators::{
88
common_cpu::{Cpu, ThisThread},
99
random_sample::{common_cpu::Operator as CpuOp, KVPair, SampleArgs},
10-
QueueOf,
10+
ByteOf, QueueOf,
1111
};
12-
use std::slice::from_raw_parts_mut;
12+
use std::{ops::Deref, slice::from_raw_parts_mut};
1313
use tensor::{ArrayLayout, BigEndian, Tensor};
1414

1515
pub struct Llama {
@@ -62,7 +62,7 @@ impl Llama {
6262
let mut embd_buf = vec![0u8; embd.shape().iter().product::<usize>() * ele];
6363
let mut logits_buf = vec![0u8; logits.shape().iter().product::<usize>() * ele];
6464

65-
let d = embd.shape()[1];
65+
let d = embd.shape()[1] * ele;
6666
for (i, &tok) in input.iter().enumerate() {
6767
embd_buf[i * d..][..d].copy_from_slice(&self.token_embed[tok as usize * d..][..d]);
6868
}
@@ -132,6 +132,13 @@ impl llama::Operators for Operators {
132132
type AttnKVCached = op!(attention_kv_cached);
133133
type Mlp = op!(mlp);
134134
type Rearrange = op!(rearrange);
135+
136+
fn debug<T>(tensor: &Tensor<T>)
137+
where
138+
T: Deref<Target = [ByteOf<Self::Hardware>]>,
139+
{
140+
println!("{tensor}");
141+
}
135142
}
136143

137144
struct Weights {
@@ -174,14 +181,19 @@ impl WeightLoader for Weights {
174181
}
175182

176183
#[test]
177-
fn test_load() {
178-
use gguf::GGufModel;
179-
use std::{io::Write, slice::from_raw_parts};
184+
fn test_infer() {
185+
use gguf::{GGufMetaMapExt, GGufModel};
186+
use std::{
187+
io::Write,
188+
slice::from_raw_parts,
189+
time::{Duration, Instant},
190+
};
180191

181192
let Some(shards) = test_utils::map_gguf_files() else {
182193
return;
183194
};
184195
let gguf = GGufModel::read(shards.iter().map(|s| &**s));
196+
let eos = gguf.tokenizer_ggml_eos_token_id().unwrap();
185197
let tokenizer = gguf.tokenizer();
186198
let llama =
187199
LlamaStorage::from_gguf(&gguf).map(&mut |s| unsafe { from_raw_parts(s.as_ptr(), s.len()) });
@@ -194,14 +206,50 @@ fn test_load() {
194206
let mut cache_buf = vec![0u8; cache.shape().iter().product::<usize>() * size_of::<f16>()];
195207

196208
let mut prompt = "Once upon a time,".to_string();
209+
210+
print!("{prompt}");
211+
std::io::stdout().flush().unwrap();
212+
197213
let mut tokens = tokenizer.encode(&prompt);
198-
while !tokens.contains(&2) {
199-
let next = llama.infer(&tokens, &mut cache_buf, 0);
200-
tokens = vec![next];
214+
let num_prompt_tokens = tokens.len();
215+
216+
let mut prefill = Duration::ZERO;
217+
let mut decode = Duration::ZERO;
218+
219+
let mut pos = 0;
220+
loop {
221+
let time = Instant::now();
222+
let next = llama.infer(&tokens, &mut cache_buf, pos);
223+
let time = time.elapsed();
224+
225+
if prefill.is_zero() {
226+
prefill = time;
227+
} else {
228+
decode += time;
229+
}
230+
231+
pos += tokens.len();
232+
if next == eos {
233+
break;
234+
}
201235

202236
let piece = tokenizer.decode(next);
203237
print!("{piece}");
204238
std::io::stdout().flush().unwrap();
205239
prompt.push_str(&piece);
240+
tokens = vec![next];
241+
}
242+
243+
println!();
244+
println!();
245+
print_time("total", prefill + decode, pos);
246+
print_time("prefill", prefill, num_prompt_tokens);
247+
print_time("decode", decode, pos - num_prompt_tokens);
248+
249+
fn print_time(name: &str, time: Duration, n: usize) {
250+
println!(
251+
"{name} : {time:?} for {n} tokens, avg: {:?} per token",
252+
time.div_f64(n as _)
253+
);
206254
}
207255
}

models/llama/common/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@ authors = ["YdrMaster <ydrml@hotmail.com>"]
77
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
88

99
[dependencies]
10-
ggus.workspace = true
1110
gguf.workspace = true
1211
operators.workspace = true
1312
tensor.workspace = true

models/llama/common/src/compute.rs

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
use super::{args::Args, LlamaMeta};
2-
use ggus::ggml_quants::digit_layout::types as ty;
2+
use gguf::ggml_quants::digit_layout::types as ty;
33
use itertools::izip;
44
use operators::{
55
attention_kv_cached::AttnKVCached,
@@ -21,6 +21,10 @@ pub trait Operators {
2121
type AttnKVCached: AttnKVCached<Self::Hardware>;
2222
type Mlp: Mlp<Self::Hardware>;
2323
type Rearrange: Rearrange<Self::Hardware>;
24+
25+
fn debug<T>(tensor: &Tensor<T>)
26+
where
27+
T: Deref<Target = [ByteOf<Self::Hardware>]>;
2428
}
2529

2630
pub enum BlkWeight {
@@ -255,8 +259,8 @@ where
255259
let x_ = unsafe { x.map_slice_static() };
256260
self.rms_norm(&mut x, &x_, &w, workspace, queue_alloc)?;
257261

258-
let lm_head = self.weights.output(queue);
259-
self.mat_mul(&mut logits, 0., &x, &lm_head, 1., workspace, queue_alloc)
262+
let output = self.weights.output(queue);
263+
self.mat_mul(&mut logits, 0., &x, &output, 1., workspace, queue_alloc)
260264
}
261265
}
262266

models/llama/common/src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,12 +3,11 @@ mod compute;
33
mod random_sample;
44
mod storage;
55

6-
use ggus::ggml_quants::digit_layout::DigitLayout;
76
use tensor::Tensor;
87

98
pub use args::{Args as LlamaArgs, Request as LlamaRequest};
109
pub use compute::{BlkWeight, LlamaBlks, Operators, WeightLoader};
11-
pub use ggus::ggml_quants::digit_layout::types as primitive;
10+
pub use gguf::ggml_quants::digit_layout::{types as primitive, DigitLayout};
1211
pub use random_sample::RandomSample;
1312
pub use storage::{BlkStorage as LlamaBlkStorage, Storage as LlamaStorage};
1413

models/llama/common/src/storage.rs

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
use crate::LlamaMeta;
2-
use gguf::GGufModel;
3-
use ggus::{GGufMetaError, GGufMetaMapExt};
2+
use gguf::{GGufMetaError, GGufMetaMapExt, GGufModel};
43

54
#[derive(Clone)]
65
pub struct Storage<T> {

0 commit comments

Comments (0)