
Commit bed406f

style(test): organize shared test code
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 4a11840 · commit bed406f

2 files changed: +103 additions, −84 deletions


test-utils/src/lib.rs

Lines changed: 32 additions & 84 deletions
@@ -1,17 +1,19 @@
+#[cfg(feature = "llama")]
+mod llama;
+
+#[cfg(feature = "llama")]
+pub use llama::{test_infer_paralle, Task, WorkerSeed};
+
 use gguf::{
     ext::{utok, Mmap},
     map_files, GGufMetaMapExt, GGufModel, Message, Tokenizer,
 };
 use std::{
     env::{var, var_os},
     fmt,
-    iter::zip,
     path::{Path, PathBuf},
     str::FromStr,
-    sync::{
-        mpsc::{self, Sender},
-        Once,
-    },
+    sync::Once,
     time::{Duration, Instant},
 };

@@ -26,13 +28,27 @@ pub struct Inference {
     pub max_steps: usize,
 }

+mod env {
+    pub const TEST_MODEL: &str = "TEST_MODEL";
+    pub const TEST_IMAGE: &str = "TEST_IMAGE";
+    pub const DEVICES: &str = "DEVICES";
+    pub const PROMPT: &str = "PROMPT";
+    pub const AS_USER: &str = "AS_USER";
+    pub const TEMPERATURE: &str = "TEMPERATURE";
+    pub const TOP_P: &str = "TOP_P";
+    pub const TOP_K: &str = "TOP_K";
+    pub const MAX_STEPS: &str = "MAX_STEPS";
+    pub const ROLL_CACHE_SIZE: &str = "ROLL_CACHE_SIZE";
+}
+use env::*;
+
 impl Inference {
     pub fn load() -> Option<Self> {
         static ONCE: Once = Once::new();
         ONCE.call_once(env_logger::init);

-        let Some(path) = var_os("TEST_MODEL") else {
-            println!("TEST_MODEL not set");
+        let Some(path) = var_os(TEST_MODEL) else {
+            println!("{TEST_MODEL} not set");
             return None;
         };
         let path = Path::new(&path);
@@ -50,26 +66,26 @@ impl Inference {

         Some(Self {
             model: map_files(path),
-            devices: var("DEVICES").ok(),
-            prompt: var("PROMPT").unwrap_or_else(|_| String::from("Once upon a time,")),
-            as_user: var("AS_USER").ok().is_some_and(|s| !s.is_empty()),
-            temperature: parse("TEMPERATURE", 0.),
-            top_p: parse("TOP_P", 1.),
-            top_k: parse("TOP_K", usize::MAX),
-            max_steps: parse("MAX_STEPS", usize::MAX),
+            devices: var(DEVICES).ok(),
+            prompt: var(PROMPT).unwrap_or_else(|_| String::from("Once upon a time,")),
+            as_user: var(AS_USER).ok().is_some_and(|s| !s.is_empty()),
+            temperature: parse(TEMPERATURE, 0.),
+            top_p: parse(TOP_P, 1.),
+            top_k: parse(TOP_K, usize::MAX),
+            max_steps: parse(MAX_STEPS, usize::MAX),
         })
     }
 }

 pub fn load_roll_cache_size() -> usize {
-    var("ROLL_CACHE_SIZE")
+    var(ROLL_CACHE_SIZE)
         .ok()
         .and_then(|s| s.parse().ok())
         .unwrap_or(usize::MAX)
 }

 pub fn image() -> Option<PathBuf> {
-    var_os("TEST_IMAGE").map(PathBuf::from)
+    var_os(TEST_IMAGE).map(PathBuf::from)
 }

 pub struct TokenizerAndPrompt {
@@ -179,71 +195,3 @@ pub fn test_infer(
         ]
     }
 }
-
-#[cfg(feature = "llama")]
-pub fn test_infer_paralle<'w>(
-    model: &llama::LlamaStorage<&'w [u8]>,
-    senders: Box<[mpsc::Sender<Task>]>,
-    eos: utok,
-    tokenizer: Tokenizer,
-    prompt: &str,
-    max_steps: usize,
-) {
-    use tensor::Blob;
-
-    let (next, next_recv) = mpsc::channel();
-    test_infer(eos, tokenizer, prompt, max_steps, |input, pos| {
-        let mut embd = model.meta.embd(input.len()).map(Blob::new).take();
-
-        let d = embd.len() / input.len();
-        for (i, &tok) in input.iter().enumerate() {
-            embd[i * d..][..d].copy_from_slice(&model.token_embd[tok as usize * d..][..d]);
-        }
-
-        for sender in &senders {
-            sender
-                .send(Task {
-                    nt: input.len(),
-                    pos,
-                    embd: embd.as_ptr(),
-                    next: next.clone(),
-                })
-                .unwrap()
-        }
-        next_recv.recv().unwrap()
-    });
-}
-
-pub struct Task {
-    pub nt: usize,
-    pub pos: usize,
-    pub embd: *const u8,
-    pub next: mpsc::Sender<utok>,
-}
-
-unsafe impl Send for Task {}
-
-pub struct WorkerSeed<N> {
-    pub tasks: mpsc::Receiver<Task>,
-    pub node: N,
-}
-
-impl<N> WorkerSeed<N> {
-    pub fn new(nodes: Vec<N>) -> (Vec<Self>, Vec<Sender<Task>>) {
-        let n = nodes.len();
-
-        let mut tasks = Vec::with_capacity(n);
-        let mut senders = Vec::with_capacity(n);
-        for _ in 0..n {
-            let (sender, receiver) = std::sync::mpsc::channel();
-            tasks.push(receiver);
-            senders.push(sender);
-        }
-        (
-            zip(nodes, tasks)
-                .map(|(node, tasks)| Self { node, tasks })
-                .collect(),
-            senders,
-        )
-    }
-}
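
A side note rather than part of the commit: the new env module gives every environment-variable name a constant, so each string literal lives in exactly one place and call sites such as parse(MAX_STEPS, ...) or var(ROLL_CACHE_SIZE) cannot drift out of sync with it. A minimal, self-contained sketch of the same pattern; the standalone max_steps helper and the single-constant module are illustrative, not code from the repository:

use std::env::var;

// The variable name is declared once; every reader goes through the constant.
mod env {
    pub const MAX_STEPS: &str = "MAX_STEPS";
}
use env::MAX_STEPS;

// Mirrors the fallback pattern in the diff: a missing or unparsable value
// yields the default instead of panicking.
fn max_steps() -> usize {
    var(MAX_STEPS)
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(usize::MAX)
}

fn main() {
    // Constants participate in inline format captures, just as the diff
    // does with println!("{TEST_MODEL} not set").
    println!("{MAX_STEPS} = {}", max_steps());
}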

test-utils/src/llama.rs

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+use crate::test_infer;
+use gguf::{ext::utok, Tokenizer};
+use llama::LlamaStorage;
+use std::{iter::zip, sync::mpsc};
+
+pub fn test_infer_paralle<'w>(
+    model: &LlamaStorage<&'w [u8]>,
+    senders: Box<[mpsc::Sender<Task>]>,
+    eos: utok,
+    tokenizer: Tokenizer,
+    prompt: &str,
+    max_steps: usize,
+) {
+    use tensor::Blob;
+
+    let (next, next_recv) = mpsc::channel();
+    test_infer(eos, tokenizer, prompt, max_steps, |input, pos| {
+        let mut embd = model.meta.embd(input.len()).map(Blob::new).take();
+
+        let d = embd.len() / input.len();
+        for (i, &tok) in input.iter().enumerate() {
+            embd[i * d..][..d].copy_from_slice(&model.token_embd[tok as usize * d..][..d]);
+        }
+
+        for sender in &senders {
+            sender
+                .send(Task {
+                    nt: input.len(),
+                    pos,
+                    embd: embd.as_ptr(),
+                    next: next.clone(),
+                })
+                .unwrap()
+        }
+        next_recv.recv().unwrap()
+    });
+}
+
+pub struct Task {
+    pub nt: usize,
+    pub pos: usize,
+    pub embd: *const u8,
+    pub next: mpsc::Sender<utok>,
+}
+
+unsafe impl Send for Task {}
+
+pub struct WorkerSeed<N> {
+    pub tasks: mpsc::Receiver<Task>,
+    pub node: N,
+}
+
+impl<N> WorkerSeed<N> {
+    pub fn new(nodes: Vec<N>) -> (Vec<Self>, Vec<mpsc::Sender<Task>>) {
+        let n = nodes.len();
+
+        let mut tasks = Vec::with_capacity(n);
+        let mut senders = Vec::with_capacity(n);
+        for _ in 0..n {
+            let (sender, receiver) = std::sync::mpsc::channel();
+            tasks.push(receiver);
+            senders.push(sender);
+        }
+        (
+            zip(nodes, tasks)
+                .map(|(node, tasks)| Self { node, tasks })
+                .collect(),
+            senders,
+        )
+    }
+}
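
Also a side note, not part of the commit: a sketch of how a backend test might combine WorkerSeed::new with test_infer_paralle, spawning one thread per node to drain that node's Task receiver and handing the matching senders to the driver. The spawn_workers helper, the run_worker callback, and the test_utils crate name are assumptions for illustration; WorkerSeed, Task, and the Box<[mpsc::Sender<Task>]> shape come from this file (re-exported from the crate root behind the llama feature):

use std::{sync::mpsc, thread};
use test_utils::{Task, WorkerSeed};

// Hypothetical helper: one OS thread per node, each presumably looping on
// its Task receiver until the driver drops the corresponding sender.
fn spawn_workers<N: Send + 'static>(
    nodes: Vec<N>,
    run_worker: fn(N, mpsc::Receiver<Task>),
) -> Box<[mpsc::Sender<Task>]> {
    let (seeds, senders) = WorkerSeed::new(nodes);
    for WorkerSeed { node, tasks } in seeds {
        thread::spawn(move || run_worker(node, tasks));
    }
    // test_infer_paralle expects the senders boxed as a slice.
    senders.into_boxed_slice()
}

The returned senders then go to test_infer_paralle together with the model, tokenizer, eos token, prompt, and max_steps, matching the signature added here.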
