Commit a7f7b48

fix(gpt2): refactor single-machine CPU inference for GPT-2
1 parent 0146337 commit a7f7b48

8 files changed: +46 -77 lines changed


Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ itertools = "0.13"
 env_logger = "0.11"
 build-script-cfg = "0.0"
 
-operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "7886d54", default-features = false }
+operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "359b86a", default-features = false }
 
 search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "f69b160" }
 search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "e8362c3" }

models/gpt2/common-cpu/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ common.workspace = true
 operators = { workspace = true, features = ["common-cpu"] }
 
 [dev-dependencies]
-test-utils = { workspace = true, features = ["llama"] }
+test-utils = { workspace = true, features = ["gpt2"] }
 gguf.workspace = true
 regex.workspace = true

models/gpt2/common-cpu/src/lib.rs

Lines changed: 0 additions & 10 deletions
@@ -22,16 +22,6 @@ pub struct Weights<'w> {
     output_norm_b: &'w [u8],
     output: &'w [u8],
     pos_embd: &'w [u8],
-    // dt_embd: DigitLayout,
-    // dt_mat: DigitLayout,
-    // size_qkv_b: usize,
-    // size_qkv_w: usize,
-    // size_o_b: usize,
-    // size_o_w: usize,
-    // size_up_b: usize,
-    // size_up_w: usize,
-    // size_down_b: usize,
-    // size_down_w: usize,
 }
 
 macro_rules! op {

models/gpt2/common/src/storage.rs

Lines changed: 1 addition & 1 deletion
@@ -291,4 +291,4 @@ impl Gpt2Meta {
         }
         own(ans.take())
     }
-}
\ No newline at end of file
+}

models/gpt2/cuda/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -18,6 +18,6 @@ search-cuda-tools.workspace = true
 search-corex-tools.workspace = true
 
 [dev-dependencies]
-test-utils.workspace = true
+test-utils = { workspace = true, features = ["gpt2"] }
 gguf.workspace = true
 regex.workspace = true

models/llama/common-cpu/src/infer.rs

Lines changed: 0 additions & 1 deletion
@@ -22,7 +22,6 @@ type Worker<'w> = LlamaWorker<Operators<InprocNode<usize>, AllReduce>, Weights<'
 
 #[test]
 fn test_infer() {
-    std::env::set_var("TEST_MODEL", r"F:\TinyLlama-1.1B-Chat-v1.0-F16.gguf");
     let Some(Inference {
         model,
         devices,
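
With the hard-coded Windows path gone, the test now resolves the model purely from the environment. A minimal sketch of that lookup, using only std (the helper name `test_model_path` is illustrative, not the crate's API):

    use std::path::PathBuf;

    /// Read the GGUF model path from the TEST_MODEL environment variable,
    /// returning None so callers can skip the test when it is unset.
    fn test_model_path() -> Option<PathBuf> {
        std::env::var("TEST_MODEL").ok().map(PathBuf::from)
    }

The test is then driven externally, e.g. `TEST_MODEL=/path/to/model.gguf cargo test test_infer`.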

models/llama/opencl/src/lib.rs

Lines changed: 36 additions & 55 deletions
@@ -1,6 +1,6 @@
 #![cfg(detected)]
 
-use common::{Distribution, WeightMemCalculator};
+use common::Distribution;
 use llama::{LlamaBlkStorage, LlamaBlkWeight, LlamaStorage, Tensor, WeightLoader};
 use operators::{
     all_reduce::{AllReduce, NonAllReduce},
@@ -10,12 +10,7 @@ use operators::{
     rearrange::opencl::Operator as Rearrange,
     Blob, ByteOf, QueueOf, TopoNode,
 };
-use std::{
-    iter::zip,
-    marker::PhantomData,
-    ops::{Deref, Range},
-    ptr::copy_nonoverlapping,
-};
+use std::{marker::PhantomData, ops::Deref, ptr::copy_nonoverlapping};
 
 pub struct Operators<N = ClDevice, R = NonAllReduce<ClDevice, Rearrange>>(PhantomData<(N, R)>);
 
@@ -65,10 +60,9 @@
 
 pub struct Weights {
     nexp: usize,
-    mem: SvmBlob,
-    blks: Box<[LlamaBlkStorage<Range<usize>>]>,
-    output_norm: Range<usize>,
-    output: Range<usize>,
+    blks: Box<[LlamaBlkStorage<SvmBlob>]>,
+    output_norm: SvmBlob,
+    output: SvmBlob,
 }
 
 impl Weights {
@@ -81,52 +75,40 @@ impl Weights {
             ..
         } = model;
 
-        let mut calculator = WeightMemCalculator::new(size_of::<usize>());
-        let meta_dist = meta.distribute(dist);
-        let blk_size = meta_dist.blk();
-        let off_blks = (0..meta_dist.nblk)
-            .map(|_| {
-                blk_size
-                    .clone()
+        let meta = meta.distribute(dist);
+        let queue = ctx.queue();
+        let blks = blocks
+            .iter()
+            .map(|blk| {
+                blk.clone()
                     .into_vec()
                     .into_iter()
-                    .map(|(which, size)| (which, calculator.push(size)))
+                    .map(|(which, data)| {
+                        let blob = meta.distribute_data(which, data, dist, Blob::new);
+                        let mut svm = ctx.malloc::<u8>(blob.len());
+                        let mut map = queue.map_mut(&mut svm, false);
+                        map.copy_from_slice(&blob);
+                        queue.unmap(map);
+                        (which, svm)
+                    })
                    .collect::<LlamaBlkStorage<_>>()
             })
             .collect::<Vec<_>>();
-        let off_output_norm = calculator.push(output_norm.len());
-        let off_output = calculator.push(output.len());
 
-        let mut mem = ctx.malloc::<u8>(calculator.size());
-        let queue = ctx.queue();
-
-        for (blk, off) in zip(blocks, off_blks.clone()) {
-            let blk = blk.clone().into_vec();
-            let off = off.into_vec();
-            for ((which, data), (which_, off)) in zip(blk, off) {
-                assert_eq!(which, which_);
-                if off.is_empty() {
-                    continue;
-                }
-                let data = meta.distribute_data(which, data, dist, Blob::new);
-                let mut map = queue.map_mut(&mut mem[off], false);
-                map.copy_from_slice(&data);
-                queue.unmap(map)
-            }
-        }
-        let mut map = queue.map_mut(&mut mem[off_output_norm.clone()], false);
-        map.copy_from_slice(output_norm);
-        queue.unmap(map);
-        let mut map = queue.map_mut(&mut mem[off_output.clone()], false);
-        map.copy_from_slice(output);
-        queue.unmap(map);
+        let mut output_norm_svm = ctx.malloc::<u8>(output_norm.len());
+        let mut output_svm = ctx.malloc::<u8>(output.len());
+        let mut output_norm_map = queue.map_mut(&mut output_norm_svm, false);
+        let mut output_map = queue.map_mut(&mut output_svm, false);
+        output_norm_map.copy_from_slice(output_norm);
+        output_map.copy_from_slice(output);
+        queue.unmap(output_norm_map);
+        queue.unmap(output_map);
 
         Self {
             nexp: meta.nexp,
-            mem,
-            blks: off_blks.into_boxed_slice(),
-            output_norm: off_output_norm,
-            output: off_output,
+            blks: blks.into_boxed_slice(),
+            output_norm: output_norm_svm,
+            output: output_svm,
         }
     }
 }
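
The hunk above replaces offset bookkeeping (one large SVM allocation indexed by Range<usize> offsets from WeightMemCalculator) with one SvmBlob per tensor. A pure-Rust sketch of the two layouts, with Vec<u8> standing in for SvmBlob (illustrative only, not the crate's code):

    use std::ops::Range;

    // Before: one packed allocation; every tensor is addressed by an
    // offset range into `mem`, computed up front by the calculator.
    struct PackedWeights {
        mem: Vec<u8>,
        output_norm: Range<usize>,
        output: Range<usize>,
    }

    // After: each tensor owns its allocation, so loading maps, copies
    // and unmaps one blob at a time, and readers never touch offsets.
    struct PerTensorWeights {
        output_norm: Vec<u8>,
        output: Vec<u8>,
    }

The trade-off is more allocations in exchange for simpler accessors: `load_blk` can hand back a blob directly instead of slicing `self.mem[range]`, as the next hunks show.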
@@ -158,7 +140,7 @@ impl WeightLoader for Weights {
 
         use LlamaBlkWeight as W;
         #[rustfmt::skip]
-        let range = match which {
+        let ans = match which {
             W::AttnNorm    => attn_norm    ,
             W::AttnQKV     => attn_qkv     ,
             W::AttnQKVBias => attn_qkv_bias,
@@ -168,7 +150,7 @@ impl WeightLoader for Weights {
             W::FfnGateUp   => ffn_gate_up  ,
             W::FfnDown     => ffn_down     ,
         };
-        &self.mem[range.clone()]
+        ans
     }
 
     fn load_moe<'a>(
@@ -184,26 +166,25 @@ impl WeightLoader for Weights {
             ..
         } = &self.blks[iblk];
 
-        let range = match which {
+        let w = match which {
             LlamaBlkWeight::FfnGateUp => ffn_gate_up,
             LlamaBlkWeight::FfnDown => ffn_down,
             _ => unreachable!(),
         };
-        let w = &self.mem[range.clone()];
         let one = w.len() / self.nexp;
         &w[iexp * one..][..one]
     }
 
     #[inline]
     fn output_norm(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Weight<'_> {
-        &self.mem[self.output_norm.clone()]
+        &self.output_norm
     }
 
     #[inline]
     fn output(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Weight<'_> {
-        &self.mem[self.output.clone()]
+        &self.output
     }
 }
 
 #[cfg(test)]
-mod infer;
\ No newline at end of file
+mod infer;
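
`load_moe` keeps the fused expert layout: one buffer packs `nexp` equally sized expert weights back to back, and expert `iexp` is cut out by offset. A self-contained sketch of that slicing, lifted from the body above (the free function is illustrative):

    /// Slice expert `iexp` out of a buffer packing `nexp` equal-size experts.
    fn expert_slice(w: &[u8], nexp: usize, iexp: usize) -> &[u8] {
        assert_eq!(w.len() % nexp, 0, "buffer must split evenly across experts");
        let one = w.len() / nexp;
        &w[iexp * one..][..one]
    }

For example, `expert_slice(&[0, 1, 2, 3, 4, 5], 3, 1)` yields `[2, 3]`.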

test-utils/src/lib.rs

Lines changed: 6 additions & 7 deletions
@@ -5,17 +5,17 @@ use gguf::{
 use std::{
     env::{var, var_os},
     fmt,
+    iter::zip,
     path::{Path, PathBuf},
     str::FromStr,
-    sync::{Once,mpsc},
-    iter::zip,
+    sync::{mpsc, Once},
     time::{Duration, Instant},
 };
 #[cfg(feature = "llama")]
-mod llama{
+mod llama {
+    use crate::InferStorage;
     use llama::LlamaStorage;
     use tensor::Tensor;
-    use crate::InferStorage;
 
     impl InferStorage for &LlamaStorage<&[u8]> {
         fn embd(&self, nt: usize) -> Tensor<usize> {
@@ -27,10 +27,10 @@ mod llama{
     }
 }
 #[cfg(feature = "llama")]
-mod gpt2{
+mod gpt2 {
+    use crate::InferStorage;
     use gpt2::GPT2Storage;
     use tensor::Tensor;
-    use crate::InferStorage;
 
     impl InferStorage for &GPT2Storage<&[u8]> {
         fn embd(&self, nt: usize) -> Tensor<usize> {
@@ -42,7 +42,6 @@ mod gpt2{
     }
 }
 
-
 pub trait InferStorage {
     fn embd(&self, nt: usize) -> Tensor<usize>;
     fn token_embd(&self) -> &[u8];
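
`InferStorage` is the seam that lets one test harness serve both model families: each storage answers how to shape the embedding tensor for `nt` tokens and where the raw token-embedding bytes live. A minimal sketch of harness code generic over the trait (the function is an illustration, not part of test-utils):

    /// Any storage implementing InferStorage can drive the same smoke test.
    fn smoke_test<S: InferStorage>(storage: S, nt: usize) {
        let _embd = storage.embd(nt); // embedding tensor shaped for nt tokens
        let _table: &[u8] = storage.token_embd(); // raw token-embedding bytes
    }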
