Skip to content

Commit d67f60f

Browse files
authored
Merge pull request #29 from onenewcode/dev
feat(gpt2):实现gpt2模型cuda版本单机推理 (feat(gpt2): implement single-machine CUDA inference for the GPT-2 model)
2 parents e73b0fb + 47beb03 commit d67f60f

File tree

14 files changed

+880
-258
lines changed

14 files changed

+880
-258
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ members = [
1616

1717
"models/gpt2/common",
1818
"models/gpt2/common-cpu",
19+
"models/gpt2/cuda",
1920
]
2021
resolver = "2"
2122

models/gpt2/common-cpu/Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@ authors = ["onenewcode <[email protected]>", "YdrMaster <[email protected]>"]
88

99
[dependencies]
1010
gpt2.path = "../common"
11+
common.workspace = true
1112
operators = { workspace = true, features = ["common-cpu"] }
1213

1314
[dev-dependencies]
14-
test-utils.workspace = true
15+
test-utils = { workspace = true, features = ["gpt2"] }
1516
gguf.workspace = true
16-
ndarray-layout.workspace = true
17+
regex.workspace = true
Lines changed: 107 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,35 @@
11
use crate::{Operators, RandomSample, Weights};
2+
use common::Distribution;
23
use gguf::GGufModel;
3-
use gpt2::{ext::ggml_quants::f16, Gpt2Meta, Gpt2Worker, Storage, Tensor};
4+
use gpt2::{ext::ggml_quants::f16, GPT2Storage, Gpt2Worker, Tensor};
45
use operators::{
5-
common_cpu::{Cpu, ThisThread},
6+
all_reduce::common_cpu::Operator as AllReduce,
7+
common_cpu::{InprocNode, ThisThread},
68
random_sample::{KVPair, SampleArgs},
79
Blob,
810
};
9-
use std::slice::from_raw_parts_mut;
10-
use test_utils::{Inference, TokenizerAndPrompt};
11-
12-
type Worker<'w> = Gpt2Worker<Operators, Weights<'w>>;
11+
use regex::Regex;
12+
use std::{
13+
iter::zip,
14+
ptr::copy_nonoverlapping,
15+
slice::from_raw_parts_mut,
16+
sync::{Arc, Barrier},
17+
thread,
18+
};
19+
use test_utils::{test_infer_paralle, Inference, Task, TokenizerAndPrompt, WorkerSeed};
1320

21+
type Worker<'w> = Gpt2Worker<Operators<InprocNode<usize>, AllReduce>, Weights<'w>>;
1422
#[test]
1523
fn test_infer() {
1624
let Some(Inference {
1725
model,
26+
devices,
1827
prompt,
1928
as_user,
2029
temperature,
2130
top_p,
2231
top_k,
2332
max_steps,
24-
..
2533
}) = Inference::load()
2634
else {
2735
return;
@@ -34,73 +42,104 @@ fn test_infer() {
3442
prompt,
3543
} = TokenizerAndPrompt::new(&gguf, prompt, as_user);
3644

37-
let model = Storage::from_gguf(&gguf);
45+
let model = GPT2Storage::from_gguf(&gguf);
3846
println!("{:?}", model.meta);
3947

4048
let sample_args = SampleArgs::new(temperature, top_p, top_k).expect("invalid sample args");
4149
println!("{sample_args:?}");
4250

43-
let &Gpt2Meta {
44-
dt_embd,
45-
nctx,
46-
nvoc,
47-
d,
48-
..
49-
} = &model.meta;
50-
let weights = Weights::new(&model);
51-
let mut worker = Worker::new(&Cpu, model.meta.clone(), weights);
52-
let mut cache = model.meta.kv_cache(nctx).map(Blob::new);
53-
let indices = RandomSample::build_indices(nvoc, &ThisThread);
54-
let sample = RandomSample::new(&Cpu);
51+
let lens = devices
52+
.map(|devices| {
53+
Regex::new(r"\d+")
54+
.unwrap()
55+
.find_iter(&devices)
56+
.map(|c| c.as_str().parse().unwrap())
57+
.collect()
58+
})
59+
.unwrap_or_else(|| vec![1]);
60+
let dist = lens.iter().sum();
61+
println!("distribution: {lens:?}");
62+
63+
let (seeds, senders) = WorkerSeed::new(InprocNode::new(lens.len()));
64+
let barrier = Arc::new(Barrier::new(dist + 1));
65+
thread::scope(|s| {
66+
let _workers = zip(lens, seeds)
67+
.enumerate()
68+
.scan(0, |start, (id, (len, seed))| {
69+
let dist = Distribution::new(*start, len, dist);
70+
*start += len;
5571

56-
test_utils::test_infer(eos, tokenizer, &prompt, max_steps, |input, pos| {
57-
// 词汇编码缓存
58-
let mut embd = Tensor::new(dt_embd, &[input.len(), d]).map(Blob::new);
59-
// 词汇位置缓存
60-
let mut logits = model.meta.logits(1).map(Blob::new);
61-
let l = embd.get().len() / input.len();
62-
for (i, &tok) in input.iter().enumerate() {
63-
embd.get_mut()[i * l..][..l]
64-
.copy_from_slice(&model.token_embd[tok as usize * l..][..l]);
65-
}
66-
worker
67-
.launch(
68-
gpt2::args::Args {
69-
embd: embd.map_slice_mut(),
70-
logits: logits.map_slice_mut(),
71-
idx: postion(input.len(), pos).map_slice(),
72-
requests: vec![gpt2::args::Request {
73-
cache: cache.map_slice_mut(),
74-
seq_len: input.len(),
75-
out_len: 1,
76-
pos,
77-
}],
78-
max_seq_len: input.len(),
79-
max_att_len: pos + input.len(),
80-
},
81-
&mut [],
82-
&ThisThread,
83-
)
84-
.unwrap();
72+
let meta = model.meta.distribute(dist);
73+
let model = &model;
74+
let barrier = barrier.clone();
75+
Some(s.spawn(move || {
76+
let WorkerSeed { node, tasks } = seed;
77+
let weights = Weights::new(model, dist);
78+
let mut worker = Worker::new(id, &node, meta.clone(), weights);
79+
let mut cache = meta.kv_cache(meta.nctx).map(Blob::new);
8580

86-
let mut pair = KVPair::new(0, f16::ZERO);
87-
let mut pairs = Tensor::kv_pair_vec(1, |_| unsafe {
88-
from_raw_parts_mut(&mut pair as *mut _ as _, size_of_val(&pair))
89-
});
81+
let sample = RandomSample::new(&node);
82+
let indices = RandomSample::build_indices(model.meta.nvoc, &ThisThread);
83+
let mut pair = KVPair::new(0, f16::ZERO);
84+
let mut pairs = Tensor::kv_pair_vec(1, |_| unsafe {
85+
from_raw_parts_mut(&mut pair as *mut _ as *mut u8, size_of_val(&pair))
86+
});
9087

91-
sample
92-
.launch(
93-
&mut pairs,
94-
&logits,
95-
&indices,
96-
sample_args,
97-
&mut [],
98-
&ThisThread,
99-
)
100-
.unwrap();
88+
barrier.wait();
89+
for task in tasks {
90+
let Task {
91+
nt,
92+
pos,
93+
embd,
94+
next,
95+
} = task;
96+
let mut embd = meta.embd(nt).map(|size| {
97+
let mut blob = Blob::new(size);
98+
unsafe { copy_nonoverlapping(embd, blob.as_mut_ptr(), size) };
99+
blob
100+
});
101+
let mut logits = meta.logits(if id == 0 { 1 } else { 0 }).map(Blob::new);
102+
worker
103+
.launch(
104+
gpt2::args::Args {
105+
embd: embd.map_slice_mut(),
106+
logits: logits.map_slice_mut(),
107+
idx: postion(nt, pos).map_slice(),
108+
requests: vec![gpt2::args::Request {
109+
cache: cache.map_slice_mut(),
110+
seq_len: nt,
111+
out_len: 1,
112+
pos,
113+
}],
114+
max_seq_len: nt,
115+
max_att_len: pos + nt,
116+
},
117+
&mut [],
118+
&ThisThread,
119+
)
120+
.unwrap();
121+
if id == 0 {
122+
sample
123+
.launch(
124+
&mut pairs,
125+
&logits,
126+
&indices,
127+
sample_args,
128+
&mut [],
129+
&ThisThread,
130+
)
131+
.unwrap();
132+
next.send(pair.idx() as _).unwrap()
133+
}
134+
}
135+
}))
136+
})
137+
.collect::<Vec<_>>();
101138

102-
pair.idx() as _
103-
});
139+
let senders = senders.into_boxed_slice();
140+
barrier.wait();
141+
test_infer_paralle(&model, senders, eos, tokenizer, &prompt, max_steps)
142+
})
104143
}
105144

106145
fn postion(l: usize, pos: usize) -> Tensor<Blob> {
@@ -109,8 +148,8 @@ fn postion(l: usize, pos: usize) -> Tensor<Blob> {
109148
let (&mut [], data, &mut []) = (unsafe { ans.get_mut().align_to_mut::<u32>() }) else {
110149
panic!()
111150
};
112-
for i in 0..l {
113-
data[i] = (pos + i) as u32;
114-
}
151+
data.iter_mut()
152+
.enumerate()
153+
.for_each(|(i, item)| *item = (pos + i) as u32);
115154
ans
116155
}

models/gpt2/common-cpu/src/lib.rs

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
1-
use gpt2::{
2-
storage::{BlkStorage, Storage},
3-
BlkWeight, Tensor, WeightLoader,
4-
};
1+
use common::{Contiguous, Distribution};
2+
use gpt2::{storage::BlkStorage, BlkWeight, Tensor, WeightLoader};
53
use operators::{
64
all_reduce::{AllReduce, NonAllReduce},
75
common_cpu::Cpu,
86
random_sample::common_cpu::Operator as RandomSampleCpu,
97
rearrange::common_cpu::Operator as Rearrange,
10-
ByteOf, QueueOf, TopoNode,
8+
Blob, ByteOf, QueueOf, TopoNode,
119
};
12-
use std::marker::PhantomData;
13-
use std::ops::Deref;
10+
use std::{marker::PhantomData, ops::Deref};
11+
12+
pub use gpt2::{GPT2BlkStorage, GPT2Storage, TensorUsage::Computation};
1413

1514
pub struct Operators<N = Cpu, R = NonAllReduce<Cpu, Rearrange>>(PhantomData<(N, R)>);
1615

1716
pub type RandomSample = gpt2::RandomSample<Cpu, RandomSampleCpu>;
1817

1918
pub struct Weights<'w> {
20-
blks: Box<[BlkStorage<&'w [u8]>]>,
19+
blks: Box<[GPT2BlkStorage<Contiguous<'w, Blob>>]>,
2120
output_norm_w: &'w [u8],
2221
output_norm_b: &'w [u8],
2322
output: &'w [u8],
@@ -55,26 +54,37 @@ where
5554
}
5655

5756
impl<'w> Weights<'w> {
58-
pub fn new(model: &'w Storage<&'w [u8]>) -> Self {
59-
let Storage {
60-
output_norm_w,
61-
output_norm_b,
57+
pub fn new(model: &'w GPT2Storage<&'w [u8]>, dist: Distribution) -> Self {
58+
let GPT2Storage {
59+
meta,
6260
output,
6361
blocks,
6462
pos_embd,
63+
output_norm_b,
64+
output_norm_w,
6565
..
6666
} = model;
6767

68+
let blks = blocks
69+
.iter()
70+
.map(|blk| {
71+
blk.into_vec()
72+
.into_iter()
73+
.map(|(which, data)| {
74+
(which, meta.distribute_data(which, data, dist, Blob::new))
75+
})
76+
.collect::<GPT2BlkStorage<_>>()
77+
})
78+
.collect::<Box<_>>();
6879
Self {
69-
pos_embd,
70-
blks: blocks.clone(),
71-
output_norm_w,
80+
blks,
7281
output_norm_b,
82+
output_norm_w,
7383
output,
84+
pos_embd,
7485
}
7586
}
7687
}
77-
7888
impl WeightLoader for Weights<'_> {
7989
type Hardware = Cpu;
8090
type Memory<'s>
@@ -103,7 +113,6 @@ impl WeightLoader for Weights<'_> {
103113
ffn_down_w,
104114
ffn_down_b,
105115
} = &self.blks[iblk];
106-
107116
match which {
108117
BlkWeight::AttnNorm => [attn_norm_w, attn_norm_b],
109118
BlkWeight::AttnQKV => [attn_qkv_w, attn_qkv_b],
@@ -113,6 +122,7 @@ impl WeightLoader for Weights<'_> {
113122
BlkWeight::FfnDown => [ffn_down_w, ffn_down_b],
114123
}
115124
}
125+
116126
#[inline]
117127
fn output_norm(&self, _queue: &QueueOf<Self::Hardware>) -> [Self::Memory<'_>; 2] {
118128
[self.output_norm_w, self.output_norm_b]

models/gpt2/common/src/compute.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ pub trait WeightLoader {
6060
}
6161

6262
pub struct Gpt2Worker<Ops: Operators, W> {
63+
#[allow(dead_code)]
64+
id: usize,
6365
meta: Gpt2Meta,
6466
weights: WeightDecorator<W>,
6567
add_rows: Ops::AddRows,
@@ -70,14 +72,14 @@ pub struct Gpt2Worker<Ops: Operators, W> {
7072
add: Ops::Add,
7173
rearrange: Ops::Rearrange,
7274
all_reduce: Ops::AllReduce,
73-
pub debug: bool,
7475
}
7576

7677
impl<Ops: Operators, W> Gpt2Worker<Ops, W> {
77-
pub fn new(node: &Ops::TopoNode, meta: Gpt2Meta, weights: W) -> Self {
78+
pub fn new(id: usize, node: &Ops::TopoNode, meta: Gpt2Meta, weights: W) -> Self {
7879
let processor = node.processor();
7980
Self {
80-
weights: meta.decorator(weights), // meta.decorator
81+
id,
82+
weights: meta.decorator(weights),
8183
meta,
8284
add_rows: Ops::AddRows::new(processor),
8385
layer_norm: Ops::LayerNorm::new(processor),
@@ -87,7 +89,6 @@ impl<Ops: Operators, W> Gpt2Worker<Ops, W> {
8789
add: Ops::Add::new(processor),
8890
rearrange: Ops::Rearrange::new(processor),
8991
all_reduce: Ops::AllReduce::new(node),
90-
debug: true,
9192
}
9293
}
9394

0 commit comments

Comments
 (0)