Skip to content

Commit 0146337

Browse files
committed
fix(gpt2): 重构gpt2的cpu的单机推理 (refactor GPT-2 single-machine CPU inference)
1 parent 1842203 commit 0146337

File tree

15 files changed, with 352 additions and 320 deletions.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ itertools = "0.13"
3535
env_logger = "0.11"
3636
build-script-cfg = "0.0"
3737

38-
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "01b1667", default-features = false }
38+
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "7886d54", default-features = false }
3939

4040
search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "f69b160" }
4141
search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "e8362c3" }

common/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ version = "0.0.0"
44
edition = "2021"
55
authors = ["YdrMaster <[email protected]>"]
66

7-
[dependencies]
7+
[dependencies]

models/gpt2/common-cpu/Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ authors = ["onenewcode <[email protected]>", "YdrMaster <[email protected]>"]
88

99
[dependencies]
1010
gpt2.path = "../common"
11+
common.workspace = true
1112
operators = { workspace = true, features = ["common-cpu"] }
1213

1314
[dev-dependencies]
14-
test-utils.workspace = true
15+
test-utils = { workspace = true, features = ["llama"] }
1516
gguf.workspace = true
16-
ndarray-layout.workspace = true
17+
regex.workspace = true
18+
Lines changed: 136 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,35 @@
11
use crate::{Operators, RandomSample, Weights};
2-
use gguf::GGufModel;
3-
use gpt2::{ext::ggml_quants::f16, GPT2Storage, Gpt2Meta, Gpt2Worker, Tensor};
2+
use common::Distribution;
3+
use gguf::ext::utok;
4+
use gguf::{GGufModel, Tokenizer};
5+
use gpt2::{ext::ggml_quants::f16, GPT2Storage, Gpt2Worker, Tensor};
6+
use operators::common_cpu::InprocNode;
47
use operators::{
5-
common_cpu::{Cpu, ThisThread},
8+
all_reduce::common_cpu::Operator as AllReduce,
9+
common_cpu::ThisThread,
610
random_sample::{KVPair, SampleArgs},
711
Blob,
812
};
13+
use regex::Regex;
14+
use std::iter::zip;
15+
use std::ptr::copy_nonoverlapping;
916
use std::slice::from_raw_parts_mut;
10-
use test_utils::{Inference, TokenizerAndPrompt};
11-
12-
type Worker<'w> = Gpt2Worker<Operators, Weights<'w>>;
17+
use std::sync::{mpsc, Arc, Barrier};
18+
use std::thread;
19+
use test_utils::{Inference, Task, TokenizerAndPrompt, WorkerSeed};
1320

21+
type Worker<'w> = Gpt2Worker<Operators<InprocNode<usize>, AllReduce>, Weights<'w>>;
1422
#[test]
1523
fn test_infer() {
1624
let Some(Inference {
1725
model,
26+
devices,
1827
prompt,
1928
as_user,
2029
temperature,
2130
top_p,
2231
top_k,
2332
max_steps,
24-
..
2533
}) = Inference::load()
2634
else {
2735
return;
@@ -40,77 +48,138 @@ fn test_infer() {
4048
let sample_args = SampleArgs::new(temperature, top_p, top_k).expect("invalid sample args");
4149
println!("{sample_args:?}");
4250

43-
let &Gpt2Meta {
44-
dt_embd,
45-
nctx,
46-
nvoc,
47-
d,
48-
..
49-
} = &model.meta;
50-
let weights = Weights::new(&model);
51-
let mut worker = Worker::new(0, &Cpu, model.meta.clone(), weights);
52-
let mut cache = model.meta.kv_cache(nctx).map(Blob::new);
53-
let indices = RandomSample::build_indices(nvoc, &ThisThread);
54-
let sample = RandomSample::new(&Cpu);
51+
let lens = devices
52+
.map(|devices| {
53+
Regex::new(r"\d+")
54+
.unwrap()
55+
.find_iter(&devices)
56+
.map(|c| c.as_str().parse().unwrap())
57+
.collect()
58+
})
59+
.unwrap_or_else(|| vec![1]);
60+
let dist = lens.iter().sum();
61+
println!("distribution: {lens:?}");
5562

56-
test_utils::test_infer(eos, tokenizer, &prompt, max_steps, |input, pos| {
57-
// 词汇编码缓存
58-
let mut embd = Tensor::new(dt_embd, &[input.len(), d]).map(Blob::new);
59-
// 词汇位置缓存
60-
let mut logits = model.meta.logits(1).map(Blob::new);
61-
let l = embd.get().len() / input.len();
62-
for (i, &tok) in input.iter().enumerate() {
63-
embd.get_mut()[i * l..][..l]
64-
.copy_from_slice(&model.token_embd[tok as usize * l..][..l]);
65-
}
66-
worker
67-
.launch(
68-
gpt2::args::Args {
69-
embd: embd.map_slice_mut(),
70-
logits: logits.map_slice_mut(),
71-
idx: postion(input.len(), pos).map_slice(),
72-
requests: vec![gpt2::args::Request {
73-
cache: cache.map_slice_mut(),
74-
seq_len: input.len(),
75-
out_len: 1,
76-
pos,
77-
}],
78-
max_seq_len: input.len(),
79-
max_att_len: pos + input.len(),
80-
},
81-
&mut [],
82-
&ThisThread,
83-
)
84-
.unwrap();
63+
let (seeds, senders) = WorkerSeed::new(InprocNode::new(lens.len()));
64+
let barrier = Arc::new(Barrier::new(dist + 1));
65+
thread::scope(|s| {
66+
let _workers = zip(lens, seeds)
67+
.enumerate()
68+
.scan(0, |start, (id, (len, seed))| {
69+
let dist = Distribution::new(*start, len, dist);
70+
*start += len;
8571

86-
let mut pair = KVPair::new(0, f16::ZERO);
87-
let mut pairs = Tensor::kv_pair_vec(1, |_| unsafe {
88-
from_raw_parts_mut(&mut pair as *mut _ as _, size_of_val(&pair))
89-
});
72+
let meta = model.meta.distribute(dist);
73+
let model = &model;
74+
let barrier = barrier.clone();
75+
Some(s.spawn(move || {
76+
let WorkerSeed { node, tasks } = seed;
77+
let weights = Weights::new(model, dist);
78+
let mut worker = Worker::new(id, &node, meta.clone(), weights);
79+
let mut cache = meta.kv_cache(meta.nctx).map(Blob::new);
9080

91-
sample
92-
.launch(
93-
&mut pairs,
94-
&logits,
95-
&indices,
96-
sample_args,
97-
&mut [],
98-
&ThisThread,
99-
)
100-
.unwrap();
81+
let sample = RandomSample::new(&node);
82+
let indices = RandomSample::build_indices(model.meta.nvoc, &ThisThread);
83+
let mut pair = KVPair::new(0, f16::ZERO);
84+
let mut pairs = Tensor::kv_pair_vec(1, |_| unsafe {
85+
from_raw_parts_mut(&mut pair as *mut _ as *mut u8, size_of_val(&pair))
86+
});
10187

102-
pair.idx() as _
103-
});
88+
barrier.wait();
89+
for task in tasks {
90+
let Task {
91+
nt,
92+
pos,
93+
embd,
94+
next,
95+
} = task;
96+
let mut embd = meta.embd(nt).map(|size| {
97+
let mut blob = Blob::new(size);
98+
unsafe { copy_nonoverlapping(embd, blob.as_mut_ptr(), size) };
99+
blob
100+
});
101+
let mut logits = meta.logits(if id == 0 { 1 } else { 0 }).map(Blob::new);
102+
worker
103+
.launch(
104+
gpt2::args::Args {
105+
embd: embd.map_slice_mut(),
106+
logits: logits.map_slice_mut(),
107+
idx: postion(nt, pos).map_slice(),
108+
requests: vec![gpt2::args::Request {
109+
cache: cache.map_slice_mut(),
110+
seq_len: nt,
111+
out_len: 1,
112+
pos,
113+
}],
114+
max_seq_len: nt,
115+
max_att_len: pos + nt,
116+
},
117+
&mut [],
118+
&ThisThread,
119+
)
120+
.unwrap();
121+
if id == 0 {
122+
sample
123+
.launch(
124+
&mut pairs,
125+
&logits,
126+
&indices,
127+
sample_args,
128+
&mut [],
129+
&ThisThread,
130+
)
131+
.unwrap();
132+
next.send(pair.idx() as _).unwrap()
133+
}
134+
}
135+
}))
136+
})
137+
.collect::<Vec<_>>();
138+
139+
let senders = senders.into_boxed_slice();
140+
barrier.wait();
141+
test_infer_par(&model, senders, eos, tokenizer, &prompt, max_steps)
142+
})
104143
}
105144

145+
pub fn test_infer_par(
146+
model: &GPT2Storage<&[u8]>,
147+
senders: Box<[mpsc::Sender<Task>]>,
148+
eos: utok,
149+
tokenizer: Tokenizer,
150+
prompt: &str,
151+
max_steps: usize,
152+
) {
153+
let (next, next_recv) = mpsc::channel();
154+
test_utils::test_infer(eos, tokenizer, prompt, max_steps, |input, pos| {
155+
let mut embd = model.meta.embd(input.len()).map(Blob::new).take();
156+
157+
let d = embd.len() / input.len();
158+
for (i, &tok) in input.iter().enumerate() {
159+
embd[i * d..][..d].copy_from_slice(&model.token_embd[tok as usize * d..][..d]);
160+
}
161+
162+
for sender in &senders {
163+
sender
164+
.send(Task {
165+
nt: input.len(),
166+
pos,
167+
embd: embd.as_ptr(),
168+
next: next.clone(),
169+
})
170+
.unwrap()
171+
}
172+
next_recv.recv().unwrap()
173+
});
174+
}
106175
fn postion(l: usize, pos: usize) -> Tensor<Blob> {
107176
use gguf::ggml_quants::digit_layout::types as ty;
108177
let mut ans = Tensor::new(ty::U32, &[1, l]).map(Blob::new);
109178
let (&mut [], data, &mut []) = (unsafe { ans.get_mut().align_to_mut::<u32>() }) else {
110179
panic!()
111180
};
112-
for i in 0..l {
113-
data[i] = (pos + i) as u32;
114-
}
181+
data.iter_mut()
182+
.enumerate()
183+
.for_each(|(i, item)| *item = (pos + i) as u32);
115184
ans
116185
}

models/gpt2/common-cpu/src/lib.rs

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,37 @@
1-
use gpt2::{
2-
storage::{BlkStorage, Storage},
3-
BlkWeight, Tensor, WeightLoader,
4-
};
5-
pub use gpt2::{GPT2BlkStorage, GPT2Storage};
1+
use common::{Contiguous, Distribution};
2+
use gpt2::{storage::BlkStorage, BlkWeight, Tensor, WeightLoader};
63
use operators::{
74
all_reduce::{AllReduce, NonAllReduce},
85
common_cpu::Cpu,
96
random_sample::common_cpu::Operator as RandomSampleCpu,
107
rearrange::common_cpu::Operator as Rearrange,
11-
ByteOf, QueueOf, TopoNode,
8+
Blob, ByteOf, QueueOf, TopoNode,
129
};
10+
use std::marker::PhantomData;
1311
use std::ops::Deref;
14-
use std::{marker::PhantomData, ptr::copy_nonoverlapping};
12+
13+
pub use gpt2::{GPT2BlkStorage, GPT2Storage, TensorUsage::Computation};
1514

1615
pub struct Operators<N = Cpu, R = NonAllReduce<Cpu, Rearrange>>(PhantomData<(N, R)>);
1716

1817
pub type RandomSample = gpt2::RandomSample<Cpu, RandomSampleCpu>;
1918

2019
pub struct Weights<'w> {
21-
blks: Box<[BlkStorage<&'w [u8]>]>,
20+
blks: Box<[GPT2BlkStorage<Contiguous<'w, Blob>>]>,
2221
output_norm_w: &'w [u8],
2322
output_norm_b: &'w [u8],
2423
output: &'w [u8],
2524
pos_embd: &'w [u8],
25+
// dt_embd: DigitLayout,
26+
// dt_mat: DigitLayout,
27+
// size_qkv_b: usize,
28+
// size_qkv_w: usize,
29+
// size_o_b: usize,
30+
// size_o_w: usize,
31+
// size_up_b: usize,
32+
// size_up_w: usize,
33+
// size_down_b: usize,
34+
// size_down_w: usize,
2635
}
2736

2837
macro_rules! op {
@@ -53,39 +62,60 @@ where
5362
{
5463
println!("{tensor}");
5564
}
56-
57-
fn memcpy_d2h<T: Copy>(
58-
dst: &mut [T],
59-
src: &[ByteOf<Self::Hardware>],
60-
_queue: &QueueOf<Self::Hardware>,
61-
) {
62-
let count = size_of_val(dst);
63-
assert_eq!(size_of_val(src), count);
64-
unsafe { copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr().cast::<u8>(), count) }
65-
}
6665
}
6766

6867
impl<'w> Weights<'w> {
69-
pub fn new(model: &'w Storage<&'w [u8]>) -> Self {
70-
let Storage {
71-
output_norm_w,
72-
output_norm_b,
68+
pub fn new(model: &'w GPT2Storage<&'w [u8]>, dist: Distribution) -> Self {
69+
let GPT2Storage {
70+
meta,
7371
output,
7472
blocks,
7573
pos_embd,
74+
output_norm_b,
75+
output_norm_w,
7676
..
7777
} = model;
7878

79+
let blks = blocks
80+
.iter()
81+
.map(|blk| {
82+
blk.into_vec()
83+
.into_iter()
84+
.map(|(which, data)| {
85+
(which, meta.distribute_data(which, data, dist, Blob::new))
86+
})
87+
.collect::<GPT2BlkStorage<_>>()
88+
})
89+
.collect::<Box<_>>();
90+
91+
// let meta = meta.distribute(dist);
92+
// let size_qkv_w = meta.attn_qkv_w(Computation).take();
93+
// let size_qkv_b = meta.attn_qkv_b(Computation).take();
94+
// let size_o_w = meta.attn_o_w(Computation).take();
95+
// let size_o_b = meta.attn_o_b(Computation).take();
96+
// let size_up_w = meta.ffn_down_w(Computation).take();
97+
// let size_up_b = meta.ffn_down_b(Computation).take();
98+
// let size_down_w = meta.ffn_down_w(Computation).take();
99+
// let size_down_b = meta.ffn_down_b(Computation).take();
79100
Self {
80-
pos_embd,
81-
blks: blocks.clone(),
82-
output_norm_w,
101+
blks,
83102
output_norm_b,
103+
output_norm_w,
84104
output,
105+
pos_embd,
106+
// dt_embd: meta.dt_embd,
107+
// dt_mat: meta.dt_linear,
108+
// size_qkv_b,
109+
// size_qkv_w,
110+
// size_o_b,
111+
// size_o_w,
112+
// size_up_b,
113+
// size_up_w,
114+
// size_down_b,
115+
// size_down_w,
85116
}
86117
}
87118
}
88-
89119
impl WeightLoader for Weights<'_> {
90120
type Hardware = Cpu;
91121
type Memory<'s>

0 commit comments

Comments
 (0)