Skip to content

Commit 2f894fa

Browse files
committed
test(infini): infini 推理测试改为分布式版本 (convert the infini inference test to a distributed version)
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent b4652ec commit 2f894fa

File tree

3 files changed

+127
-83
lines changed

3 files changed

+127
-83
lines changed

models/llama/infini/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ search-infini-tools.workspace = true
1717
[dev-dependencies]
1818
test-utils.workspace = true
1919
gguf.workspace = true
20+
regex.workspace = true

models/llama/infini/src/infer.rs

Lines changed: 125 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,32 @@
11
use crate::{Operators, RandomSample, Weights};
22
use gguf::GGufModel;
3-
use llama::{
4-
ext::ggml_quants::f16, LlamaArgs, LlamaMeta, LlamaRequest, LlamaStorage, LlamaWorker, Tensor,
5-
};
3+
use llama::{ext::ggml_quants::f16, LlamaRequest, LlamaStorage, LlamaWorker, Tensor};
64
use operators::{
75
infini_rt::{self, Device, DeviceType::DEVICE_CPU},
86
random_sample::{KVPair, SampleArgs},
7+
TopoNode,
8+
};
9+
use regex::Regex;
10+
use std::{
11+
iter::zip,
12+
slice::{from_raw_parts, from_raw_parts_mut},
13+
thread,
914
};
10-
use std::{slice::from_raw_parts_mut, thread, usize};
11-
use test_utils::{Inference, TokenizerAndPrompt};
15+
use test_utils::{test_infer_paralle, Inference, Task, TokenizerAndPrompt, WorkerSeed};
1216

1317
type Worker<'w> = LlamaWorker<Operators, Weights>;
1418

1519
#[test]
1620
fn test_infer() {
1721
let Some(Inference {
1822
model,
23+
devices,
1924
prompt,
2025
as_user,
2126
temperature,
2227
top_p,
2328
top_k,
2429
max_steps,
25-
..
2630
}) = Inference::load()
2731
else {
2832
return;
@@ -41,83 +45,122 @@ fn test_infer() {
4145
let sample_args = SampleArgs::new(temperature, top_p, top_k).expect("invalid sample args");
4246
println!("{sample_args:?}");
4347

44-
infini_rt::init(DEVICE_CPU);
45-
let device = Device {
46-
ty: DEVICE_CPU,
47-
id: 0,
48-
};
49-
50-
let meta = &model.meta;
51-
let &LlamaMeta {
52-
dt_embd,
53-
nctx,
54-
nvoc,
55-
dh,
56-
..
57-
} = meta;
48+
let devices = devices
49+
.map(|devices| {
50+
Regex::new(r"\d+")
51+
.unwrap()
52+
.find_iter(&devices)
53+
.map(|c| c.as_str().parse().unwrap())
54+
.collect()
55+
})
56+
.unwrap_or_else(|| vec![0]);
57+
let lens = vec![1; devices.len()];
58+
let count = devices.len();
59+
println!("distribution: {devices:?}");
5860

61+
infini_rt::init(DEVICE_CPU);
62+
let (seeds, senders) = WorkerSeed::new(
63+
devices
64+
.into_iter()
65+
.map(|id| Device { ty: DEVICE_CPU, id })
66+
.collect(),
67+
);
5968
thread::scope(|s| {
60-
let sample = s.spawn(move || {
61-
let mut sample = RandomSample::new(&device);
62-
sample.scheme(dt_embd, nvoc).unwrap();
63-
sample
64-
});
65-
let stream = device.stream();
66-
67-
let token_embd = device.from_host(model.token_embd);
68-
let weights = Weights::new(&model, .., 1, &stream);
69-
let mut worker = Worker::new(&device, meta.clone(), weights, true);
70-
let mut cache = meta.kv_cache(nctx).map(|size| stream.malloc::<u8>(size));
71-
let sin_cos = <Operators as llama::Operators>::build_sin_cos(dt_embd, nctx, dh, &stream);
72-
let indices = RandomSample::build_indices(nvoc, &stream);
73-
74-
let sample = sample.join().unwrap();
75-
test_utils::test_infer(eos, tokenizer, &prompt, max_steps, |input, pos| {
76-
let mut embd = meta.embd(input.len()).map(|len| stream.malloc::<u8>(len));
77-
let mut logits = meta.logits(1).map(|len| stream.malloc::<u8>(len));
78-
79-
let d = embd.get().len() / input.len();
80-
for (i, &tok) in input.iter().enumerate() {
81-
stream.memcpy_d2d(
82-
&mut embd.get_mut()[i * d..][..d],
83-
&token_embd[tok as usize * d..][..d],
84-
)
85-
}
86-
87-
worker
88-
.launch(
89-
LlamaArgs {
90-
embd: embd.map_slice_mut(),
91-
logits: logits.map_slice_mut(),
92-
sin_cos: sin_cos.map_slice(),
93-
requests: vec![LlamaRequest {
94-
cache: cache.map_slice_mut(),
95-
seq_len: input.len(),
96-
out_len: 1,
69+
let _workers = zip(lens, seeds)
70+
.enumerate()
71+
.scan(0, |start, (i, (len, seed))| {
72+
let range = *start..*start + len;
73+
*start = range.end;
74+
75+
let mut meta = model.meta.clone();
76+
meta.distribute(range.clone(), count);
77+
78+
let model = &model;
79+
Some(s.spawn(move || {
80+
let WorkerSeed { node, tasks } = seed;
81+
let device = node.processor();
82+
let stream = device.stream();
83+
let weights = Weights::new(model, range, count, &stream);
84+
let mut worker = Worker::new(&node, meta.clone(), weights, i == 0);
85+
let mut cache = meta
86+
.kv_cache(meta.nctx)
87+
.map(|size| stream.malloc::<u8>(size));
88+
let sin_cos = <Operators as llama::Operators>::build_sin_cos(
89+
meta.dt_embd,
90+
meta.nctx,
91+
meta.dh,
92+
&stream,
93+
);
94+
95+
let sample = RandomSample::new(&node);
96+
let indices = RandomSample::build_indices(model.meta.nvoc, &stream);
97+
let mut pair = KVPair::new(0, f16::ZERO);
98+
let mut pairs = Tensor::kv_pair_vec(1, |size| stream.malloc::<u8>(size));
99+
100+
for task in tasks {
101+
let Task {
102+
nt,
97103
pos,
98-
}],
99-
num_tokens: input.len(),
100-
max_seq_len: input.len(),
101-
max_att_len: pos + input.len(),
102-
},
103-
&mut [],
104-
&stream,
105-
)
106-
.unwrap();
107-
108-
let mut pairs = Tensor::kv_pair_vec(1, |size| stream.malloc::<u8>(size));
109-
110-
sample
111-
.launch(&mut pairs, &logits, &indices, sample_args, &mut [], &stream)
112-
.unwrap();
113-
114-
let mut pair = KVPair::new(0, f16::ZERO);
115-
device.memcpy_d2h(
116-
unsafe { from_raw_parts_mut(&mut pair as *mut _ as *mut u8, size_of_val(&pair)) },
117-
pairs.get(),
118-
);
119-
120-
pair.idx() as _
121-
});
122-
});
104+
embd,
105+
next,
106+
} = task;
107+
let mut embd = meta
108+
.embd(nt)
109+
.map(|size| stream.from_host(unsafe { from_raw_parts(embd, size) }));
110+
let mut logits = meta
111+
.logits(if i == 0 { 1 } else { 0 })
112+
.map(|size| stream.malloc::<u8>(size));
113+
worker
114+
.launch(
115+
llama::LlamaArgs {
116+
embd: embd.map_slice_mut(),
117+
logits: logits.map_slice_mut(),
118+
sin_cos: sin_cos.map_slice(),
119+
requests: vec![LlamaRequest {
120+
cache: cache.map_slice_mut(),
121+
seq_len: nt,
122+
out_len: if i == 0 { 1 } else { 0 },
123+
pos,
124+
}],
125+
num_tokens: nt,
126+
max_seq_len: nt,
127+
max_att_len: nt + pos,
128+
},
129+
&mut [],
130+
&stream,
131+
)
132+
.unwrap();
133+
if i == 0 {
134+
sample
135+
.launch(
136+
&mut pairs,
137+
&logits,
138+
&indices,
139+
sample_args,
140+
&mut [],
141+
&stream,
142+
)
143+
.unwrap();
144+
145+
stream.synchronize();
146+
device.memcpy_d2h(
147+
unsafe {
148+
from_raw_parts_mut(
149+
&mut pair as *mut _ as *mut u8,
150+
pairs.get().len(),
151+
)
152+
},
153+
pairs.get(),
154+
);
155+
156+
next.send(pair.idx() as _).unwrap()
157+
}
158+
}
159+
}))
160+
})
161+
.collect::<Vec<_>>();
162+
163+
let senders = senders.into_boxed_slice();
164+
test_infer_paralle(&model, senders, eos, tokenizer, &prompt, max_steps)
165+
})
123166
}

models/llama/nvidia-gpu/src/nccl_parallel.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ fn test_infer() {
5656
.map(|c| c.as_str().parse().unwrap())
5757
.collect()
5858
})
59-
.unwrap_or_else(|| vec![1]);
59+
.unwrap_or_else(|| vec![0]);
6060
let lens = vec![1; devices.len()];
6161
let count = devices.len();
6262
println!("distribution: {devices:?}");

0 commit comments

Comments (0)