Skip to content

Commit 4a11840

Browse files
committed
style(llama): clean up parallel inference
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 37c0d87 commit 4a11840

File tree

7 files changed

+119
-161
lines changed

7 files changed

+119
-161
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ common.path = "common"
2121
gguf.path = "gguf"
2222
tensor.path = "tensor"
2323
causal-lm.path = "causal-lm"
24-
test-utils.path = "test-utils"
24+
test-utils = { path = "test-utils", default-features = false }
2525

2626
ggus = "0.3"
2727
itertools = "0.13"

models/llama/common-cpu/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@ llama.path = "../common"
1111
operators = { workspace = true, features = ["common-cpu"] }
1212

1313
[dev-dependencies]
14-
test-utils.workspace = true
14+
test-utils = { workspace = true, features = ["llama"] }
1515
gguf.workspace = true
1616
regex.workspace = true

models/llama/common-cpu/src/infer.rs

Lines changed: 16 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,8 @@ use operators::{
88
Blob,
99
};
1010
use regex::Regex;
11-
use std::{
12-
iter::zip,
13-
ptr::copy_nonoverlapping,
14-
slice::from_raw_parts_mut,
15-
sync::mpsc::{Receiver, Sender},
16-
thread,
17-
};
18-
use test_utils::{Inference, TokenizerAndPrompt};
11+
use std::{iter::zip, ptr::copy_nonoverlapping, slice::from_raw_parts_mut, thread};
12+
use test_utils::{test_infer_paralle, Inference, Task, TokenizerAndPrompt, WorkerSeed};
1913

2014
type Worker<'w> = LlamaWorker<Operators<InprocNode<usize>, AllReduce>, Weights<'w>>;
2115

@@ -48,17 +42,19 @@ fn test_infer() {
4842
let sample_args = SampleArgs::new(temperature, top_p, top_k).expect("invalid sample args");
4943
println!("{sample_args:?}");
5044

51-
let lens = match devices {
52-
Some(devices) => Regex::new(r"\d+")
53-
.unwrap()
54-
.find_iter(&devices)
55-
.map(|c| c.as_str().parse::<usize>().unwrap())
56-
.collect::<Vec<_>>(),
57-
None => vec![1],
58-
};
59-
println!("distribution: {lens:?}");
45+
let lens = devices
46+
.map(|devices| {
47+
Regex::new(r"\d+")
48+
.unwrap()
49+
.find_iter(&devices)
50+
.map(|c| c.as_str().parse().unwrap())
51+
.collect()
52+
})
53+
.unwrap_or_else(|| vec![1]);
6054
let count = lens.iter().sum();
61-
let (seeds, senders) = WorkerSeed::new(lens.len());
55+
println!("distribution: {lens:?}");
56+
57+
let (seeds, senders) = WorkerSeed::new(InprocNode::new(lens.len()));
6258
thread::scope(|s| {
6359
let _workers = zip(lens, seeds)
6460
.enumerate()
@@ -70,7 +66,6 @@ fn test_infer() {
7066
meta.distribute(range.clone(), count);
7167

7268
let model = &model;
73-
7469
Some(s.spawn(move || {
7570
let WorkerSeed { node, tasks } = seed;
7671
let weights = Weights::new(model, range, count);
@@ -141,63 +136,7 @@ fn test_infer() {
141136
})
142137
.collect::<Vec<_>>();
143138

144-
let (next, next_recv) = std::sync::mpsc::channel();
145-
test_utils::test_infer(eos, tokenizer, &prompt, max_steps, |input, pos| {
146-
let mut embd = model.meta.embd(input.len()).map(Blob::new);
147-
148-
let d = embd.get().len() / input.len();
149-
for (i, &tok) in input.iter().enumerate() {
150-
embd.get_mut()[i * d..][..d]
151-
.copy_from_slice(&model.token_embd[tok as usize * d..][..d]);
152-
}
153-
let embd = embd.take();
154-
155-
for sender in &senders {
156-
sender
157-
.send(Task {
158-
nt: input.len(),
159-
pos,
160-
embd: embd.as_ptr(),
161-
next: next.clone(),
162-
})
163-
.unwrap();
164-
}
165-
next_recv.recv().unwrap()
166-
});
167-
168-
drop(senders)
139+
let senders = senders.into_boxed_slice();
140+
test_infer_paralle(&model, senders, eos, tokenizer, &prompt, max_steps)
169141
})
170142
}
171-
172-
struct Task {
173-
nt: usize,
174-
pos: usize,
175-
embd: *const u8,
176-
next: Sender<u32>,
177-
}
178-
179-
unsafe impl Send for Task {}
180-
181-
struct WorkerSeed {
182-
tasks: Receiver<Task>,
183-
node: InprocNode<usize>,
184-
}
185-
186-
impl WorkerSeed {
187-
fn new(n: usize) -> (Vec<Self>, Vec<Sender<Task>>) {
188-
let mut tasks = Vec::with_capacity(n);
189-
let mut senders = Vec::with_capacity(n);
190-
let nodes = InprocNode::new(n);
191-
for _ in 0..n {
192-
let (sender, receiver) = std::sync::mpsc::channel();
193-
tasks.push(receiver);
194-
senders.push(sender);
195-
}
196-
(
197-
zip(nodes, tasks)
198-
.map(|(node, tasks)| Self { node, tasks })
199-
.collect(),
200-
senders,
201-
)
202-
}
203-
}

models/llama/nvidia-gpu/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@ build-script-cfg.workspace = true
1515
search-cuda-tools.workspace = true
1616

1717
[dev-dependencies]
18-
test-utils.workspace = true
18+
test-utils = { workspace = true, features = ["llama"] }
1919
gguf.workspace = true
2020
regex.workspace = true

models/llama/nvidia-gpu/src/nccl_parallel.rs

Lines changed: 22 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,15 @@ use operators::{
77
nccl::CommunicatorGroup,
88
nvidia_gpu::NcclNode,
99
random_sample::{KVPair, SampleArgs},
10-
Blob, TopoNode,
10+
TopoNode,
1111
};
1212
use regex::Regex;
1313
use std::{
1414
iter::zip,
1515
slice::{from_raw_parts, from_raw_parts_mut},
16-
sync::mpsc::{Receiver, Sender},
17-
thread, usize,
16+
thread,
1817
};
19-
use test_utils::{Inference, TokenizerAndPrompt};
18+
use test_utils::{test_infer_paralle, Inference, Task, TokenizerAndPrompt, WorkerSeed};
2019

2120
type Worker<'w> = LlamaWorker<Operators<NcclNode, AllReduce>, Weights<'w>>;
2221

@@ -49,21 +48,27 @@ fn test_infer() {
4948
let sample_args = SampleArgs::new(temperature, top_p, top_k).expect("invalid sample args");
5049
println!("{sample_args:?}");
5150

52-
let devices = match devices {
53-
Some(devices) => Regex::new(r"\d+")
54-
.unwrap()
55-
.find_iter(&devices)
56-
.map(|c| c.as_str().parse().unwrap())
57-
.collect::<Vec<_>>(),
58-
None => vec![0],
59-
};
60-
println!("distribution: {devices:?}");
61-
51+
let devices = devices
52+
.map(|devices| {
53+
Regex::new(r"\d+")
54+
.unwrap()
55+
.find_iter(&devices)
56+
.map(|c| c.as_str().parse().unwrap())
57+
.collect()
58+
})
59+
.unwrap_or_else(|| vec![1]);
6260
let lens = vec![1; devices.len()];
6361
let count = devices.len();
62+
println!("distribution: {devices:?}");
6463

6564
let (seeds, senders) = match cuda::init() {
66-
Ok(()) => WorkerSeed::new(&devices),
65+
Ok(()) => WorkerSeed::new(
66+
CommunicatorGroup::new(&devices)
67+
.into_vec()
68+
.into_iter()
69+
.map(|comm| NcclNode::new(comm, Default::default()))
70+
.collect(),
71+
),
6772
Err(NoDevice) => return,
6873
};
6974
thread::scope(|s| {
@@ -77,7 +82,6 @@ fn test_infer() {
7782
meta.distribute(range.clone(), count);
7883

7984
let model = &model;
80-
8185
Some(s.spawn(move || {
8286
let WorkerSeed { node, tasks } = seed;
8387
node.processor().apply(|ctx| {
@@ -163,68 +167,7 @@ fn test_infer() {
163167
})
164168
.collect::<Vec<_>>();
165169

166-
let (next, next_recv) = std::sync::mpsc::channel();
167-
test_utils::test_infer(eos, tokenizer, &prompt, max_steps, |input, pos| {
168-
let mut embd = model.meta.embd(input.len()).map(Blob::new);
169-
170-
let d = embd.get().len() / input.len();
171-
for (i, &tok) in input.iter().enumerate() {
172-
embd.get_mut()[i * d..][..d]
173-
.copy_from_slice(&model.token_embd[tok as usize * d..][..d]);
174-
}
175-
let embd = embd.take();
176-
177-
for sender in &senders {
178-
sender
179-
.send(Task {
180-
nt: input.len(),
181-
pos,
182-
embd: embd.as_ptr(),
183-
next: next.clone(),
184-
})
185-
.unwrap();
186-
}
187-
next_recv.recv().unwrap()
188-
});
189-
190-
drop(senders)
170+
let senders = senders.into_boxed_slice();
171+
test_infer_paralle(&model, senders, eos, tokenizer, &prompt, max_steps)
191172
})
192173
}
193-
194-
struct Task {
195-
nt: usize,
196-
pos: usize,
197-
embd: *const u8,
198-
next: Sender<u32>,
199-
}
200-
201-
unsafe impl Send for Task {}
202-
203-
struct WorkerSeed {
204-
tasks: Receiver<Task>,
205-
node: NcclNode,
206-
}
207-
208-
impl WorkerSeed {
209-
fn new(devices: &[i32]) -> (Vec<Self>, Vec<Sender<Task>>) {
210-
let nodes = CommunicatorGroup::new(devices)
211-
.into_vec()
212-
.into_iter()
213-
.map(|comm| NcclNode::new(comm, Default::default()))
214-
.collect::<Vec<_>>();
215-
let n = nodes.len();
216-
let mut tasks = Vec::with_capacity(n);
217-
let mut senders = Vec::with_capacity(n);
218-
for _ in 0..n {
219-
let (sender, receiver) = std::sync::mpsc::channel();
220-
tasks.push(receiver);
221-
senders.push(sender);
222-
}
223-
(
224-
zip(nodes, tasks)
225-
.map(|(node, tasks)| Self { node, tasks })
226-
.collect(),
227-
senders,
228-
)
229-
}
230-
}

test-utils/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@ version = "0.0.0"
44
edition = "2021"
55
authors = ["YdrMaster <ydrml@hotmail.com>"]
66

7+
[features]
8+
default = ["llama"]
9+
710
[dependencies]
11+
llama = { path = "../models/llama/common", optional = true }
812
gguf.workspace = true
13+
tensor.workspace = true
914
env_logger.workspace = true
1015
cli-table = "0.4.9"

test-utils/src/lib.rs

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@ use gguf::{
55
use std::{
66
env::{var, var_os},
77
fmt,
8+
iter::zip,
89
path::{Path, PathBuf},
910
str::FromStr,
10-
sync::Once,
11+
sync::{
12+
mpsc::{self, Sender},
13+
Once,
14+
},
1115
time::{Duration, Instant},
12-
vec,
1316
};
1417

1518
pub struct Inference {
@@ -176,3 +179,71 @@ pub fn test_infer(
176179
]
177180
}
178181
}
182+
183+
#[cfg(feature = "llama")]
184+
pub fn test_infer_paralle<'w>(
185+
model: &llama::LlamaStorage<&'w [u8]>,
186+
senders: Box<[mpsc::Sender<Task>]>,
187+
eos: utok,
188+
tokenizer: Tokenizer,
189+
prompt: &str,
190+
max_steps: usize,
191+
) {
192+
use tensor::Blob;
193+
194+
let (next, next_recv) = mpsc::channel();
195+
test_infer(eos, tokenizer, prompt, max_steps, |input, pos| {
196+
let mut embd = model.meta.embd(input.len()).map(Blob::new).take();
197+
198+
let d = embd.len() / input.len();
199+
for (i, &tok) in input.iter().enumerate() {
200+
embd[i * d..][..d].copy_from_slice(&model.token_embd[tok as usize * d..][..d]);
201+
}
202+
203+
for sender in &senders {
204+
sender
205+
.send(Task {
206+
nt: input.len(),
207+
pos,
208+
embd: embd.as_ptr(),
209+
next: next.clone(),
210+
})
211+
.unwrap()
212+
}
213+
next_recv.recv().unwrap()
214+
});
215+
}
216+
217+
pub struct Task {
218+
pub nt: usize,
219+
pub pos: usize,
220+
pub embd: *const u8,
221+
pub next: mpsc::Sender<utok>,
222+
}
223+
224+
unsafe impl Send for Task {}
225+
226+
pub struct WorkerSeed<N> {
227+
pub tasks: mpsc::Receiver<Task>,
228+
pub node: N,
229+
}
230+
231+
impl<N> WorkerSeed<N> {
232+
pub fn new(nodes: Vec<N>) -> (Vec<Self>, Vec<Sender<Task>>) {
233+
let n = nodes.len();
234+
235+
let mut tasks = Vec::with_capacity(n);
236+
let mut senders = Vec::with_capacity(n);
237+
for _ in 0..n {
238+
let (sender, receiver) = std::sync::mpsc::channel();
239+
tasks.push(receiver);
240+
senders.push(sender);
241+
}
242+
(
243+
zip(nodes, tasks)
244+
.map(|(node, tasks)| Self { node, tasks })
245+
.collect(),
246+
senders,
247+
)
248+
}
249+
}

0 commit comments

Comments (0)