
Commit 04bd8a1

perf(nv): speed up large model loading
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent f16b51a commit 04bd8a1

12 files changed, +86 -22 lines changed

Cargo.toml

Lines changed: 3 additions & 3 deletions
@@ -31,7 +31,7 @@ itertools = "0.13"
 env_logger = "0.11"
 build-script-cfg = "0.0"
 
-operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "44ad48", default-features = false }
+operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "89ffbf1", default-features = false }
 search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "9b6289d" }
-search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "0e57976" }
-search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "5b9dbd9" }
+search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "f40bcb5" }
+search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "041badf" }

common/src/lib.rs

Lines changed: 28 additions & 1 deletion
@@ -1,4 +1,4 @@
-use std::ops::Deref;
+use std::{borrow::Borrow, collections::HashMap, hash::Hash, ops::Deref};
 
 pub enum Contiguous<'a, T> {
     Borrowed(&'a [u8]),
@@ -25,3 +25,30 @@ pub fn borrow<T>(t: &[u8]) -> Contiguous<'_, T> {
 pub fn own<'a, T>(t: T) -> Contiguous<'a, T> {
     Contiguous::Owned(t)
 }
+
+#[derive(Clone, Default, Debug)]
+#[repr(transparent)]
+pub struct Slab<K, V>(HashMap<K, Vec<V>>);
+
+impl<K, V> Slab<K, V> {
+    #[inline]
+    pub fn new() -> Self {
+        Self(HashMap::new())
+    }
+}
+
+impl<K: Eq + Hash, V> Slab<K, V> {
+    #[inline]
+    pub fn take<Q>(&mut self, key: &Q) -> Option<V>
+    where
+        K: Borrow<Q>,
+        Q: ?Sized + Hash + Eq,
+    {
+        self.0.get_mut(key).and_then(|pool| pool.pop())
+    }
+
+    #[inline]
+    pub fn put(&mut self, key: K, value: V) {
+        self.0.entry(key).or_default().push(value);
+    }
+}
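
common now exports a small Slab type: a HashMap from key to a stack of values, where put parks a value under a key and take pops one back out if any is pooled. A minimal usage sketch (illustrative only; the NVIDIA backend pools pinned HostMem buffers keyed by byte length, here stood in by Vec<u8>):

use common::Slab;

// Values are pooled per key; `take` pops one back out if any have been `put`.
let mut slab: Slab<usize, Vec<u8>> = Slab::new();
assert!(slab.take(&4096usize).is_none()); // nothing pooled for this size yet

slab.put(4096, vec![0u8; 4096]); // return a 4 KiB buffer to the pool
let buf = slab.take(&4096usize).expect("reuses the pooled buffer");
assert_eq!(buf.len(), 4096);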

models/llama/common-cpu/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ authors = ["YdrMaster <ydrml@hotmail.com>"]
 
 [dependencies]
 llama.path = "../common"
+common.workspace = true
 operators = { workspace = true, features = ["common-cpu"] }
 
 [dev-dependencies]

models/llama/common-cpu/src/lib.rs

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
+use common::Contiguous;
 use llama::{
     ext::ggml_quants::{self, digit_layout::DigitLayout, f16, DataBlock, QuantExt},
-    BlkWeight, Contiguous, LlamaBlkStorage, LlamaStorage, Tensor,
+    BlkWeight, LlamaBlkStorage, LlamaStorage, Tensor,
     TensorUsage::Computation,
     WeightLoader,
 };

models/llama/common/src/lib.rs

Lines changed: 0 additions & 1 deletion
@@ -6,7 +6,6 @@ use gguf::ggml_quants::digit_layout::DigitLayout;
 use std::ops::{Range, RangeBounds};
 
 pub use args::{Args as LlamaArgs, Request as LlamaRequest};
-pub use common::Contiguous;
 pub use compute::{BlkWeight, LlamaWorker, Operators, WeightLoader};
 pub use storage::{BlkStorage as LlamaBlkStorage, Storage as LlamaStorage};
 pub use tensor::{RandomSample, Tensor};

models/llama/infini/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ authors = ["YdrMaster <ydrml@hotmail.com>"]
 
 [dependencies]
 llama.path = "../common"
+common.workspace = true
 operators = { workspace = true, features = ["infini"] }
 
 [build-dependencies]

models/llama/infini/src/lib.rs

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 #![cfg(detected)]
 
-use llama::{BlkWeight, Contiguous, LlamaBlkStorage, LlamaStorage, Tensor, WeightLoader};
+use common::Contiguous;
+use llama::{BlkWeight, LlamaBlkStorage, LlamaStorage, Tensor, WeightLoader};
 use operators::{
     all_reduce::{infini::Operator as InfiniAllReduce, AllReduce},
     infini::{Device, InfiniNode},

models/llama/nvidia-gpu/Cargo.toml

Lines changed: 2 additions & 1 deletion
@@ -8,6 +8,8 @@ authors = ["YdrMaster <ydrml@hotmail.com>"]
 
 [dependencies]
 llama.path = "../common"
+common.workspace = true
+log.workspace = true
 operators = { workspace = true, features = ["nvidia-gpu"] }
 
 [build-dependencies]
@@ -17,5 +19,4 @@ search-cuda-tools.workspace = true
 [dev-dependencies]
 test-utils = { workspace = true, features = ["llama"] }
 gguf.workspace = true
-log.workspace = true
 regex.workspace = true

models/llama/nvidia-gpu/src/lib.rs

Lines changed: 38 additions & 12 deletions
@@ -1,9 +1,11 @@
 #![cfg(driver_detected)]
 
-use llama::{BlkWeight, Contiguous, LlamaBlkStorage, LlamaStorage, Tensor, WeightLoader};
+use common::{Contiguous, Slab};
+use llama::{BlkWeight, LlamaBlkStorage, LlamaStorage, Tensor, WeightLoader};
+use log::trace;
 use operators::{
     all_reduce::{AllReduce, NonAllReduce},
-    cuda::{memcpy_d2h, CurrentCtx, DevByte, DevMem, Event, HostMem, Stream},
+    cuda::{memcpy_d2h, AsRaw, CurrentCtx, DevByte, DevMem, Event, HostMem, Stream},
     nvidia_gpu::Gpu,
     random_sample::nvidia_gpu::Operator as RandomSampleGpu,
     rearrange::nvidia_gpu::Operator as Rearrange,
@@ -15,6 +17,7 @@ use std::{
     mem::replace,
     ops::{Deref, RangeBounds},
     rc::Rc,
+    time::Instant,
 };
 
 pub struct Operators<N = Gpu, R = NonAllReduce<Gpu, Rearrange>>(PhantomData<(N, R)>);
@@ -157,11 +160,14 @@ impl<'blk> Weights<'blk> {
     ) -> Self {
         assert!(pool_size > 0);
         let stream = Rc::new(ctx.stream());
+        let igpu = unsafe { ctx.dev().as_raw() };
+        let mut slab = Slab::new();
         let blks = if pool_size < model.meta.nblk {
             let mut blks_host = model.blocks[0]
                 .as_ref()
                 .map(|_| Vec::with_capacity(model.meta.nblk));
-            for blk in model.blocks.iter() {
+            for (iblk, blk) in model.blocks.iter().enumerate() {
+                let time = Instant::now();
                 let blk = blk
                     .distribute(&model.meta, range.clone(), count, |len| {
                         ctx.malloc_host::<u8>(len)
@@ -188,6 +194,7 @@ impl<'blk> Weights<'blk> {
                     ffn_gate_up
                     ffn_down
                 }
+                trace!("blk{iblk} loaded to gpu{igpu} in {:?}", time.elapsed())
             }
             blks_host.map(|vec| {
                 let roll_cache = vec
@@ -206,18 +213,26 @@ impl<'blk> Weights<'blk> {
             let mut blks_dev = model.blocks[0]
                 .as_ref()
                 .map(|_| Vec::with_capacity(model.meta.nblk));
-            for blk in &model.blocks {
-                let blk = blk.distribute(&model.meta, range.clone(), count, |len| {
-                    ctx.malloc_host::<u8>(len)
+            for (iblk, blk) in model.blocks.iter().enumerate() {
+                let blk = blk.distribute(&model.meta, range.clone(), count, |size| {
+                    slab.take(&size)
+                        .unwrap_or_else(|| ctx.malloc_host::<u8>(size))
                 });
                 let loader = loader
                     .get_or_insert_with(|| blk.as_ref().map(|s| H2DLoader::new(s.len(), &stream)));
 
                 macro_rules! load {
                     ($( $ident:ident )+ ) => {
-                        $({ blks_dev.$ident.push(loader.$ident.load(blk.$ident, &stream)); })+
+                        $(
+                            let (dev, host) = loader.$ident.load(blk.$ident, &stream);
+                            if let Some(host) = host {
+                                slab.put(host.len(), host)
+                            }
+                            blks_dev.$ident.push(dev);
+                        )+
                     };
                 }
+                let time = Instant::now();
                 load! {
                     attn_norm
                     attn_qkv
@@ -226,6 +241,7 @@ impl<'blk> Weights<'blk> {
                     ffn_gate_up
                     ffn_down
                 }
+                trace!("blk{iblk} loaded to gpu{igpu} in {:?}", time.elapsed())
            }
             blks_dev.map(|vec| Cache::Static(vec.into_boxed_slice()))
         };
@@ -253,15 +269,25 @@ impl<'ctx> H2DLoader<'ctx> {
         }
     }
 
-    fn load(&mut self, host: Contiguous<HostMem<'ctx>>, stream: &Stream<'ctx>) -> DevMem<'ctx> {
+    fn load(
+        &mut self,
+        host: Contiguous<HostMem<'ctx>>,
+        stream: &Stream<'ctx>,
+    ) -> (DevMem<'ctx>, Option<HostMem<'ctx>>) {
         self.event.synchronize();
-        match host {
-            Contiguous::Borrowed(host) => self.host.copy_from_slice(host),
-            Contiguous::Owned(host) => self.host = host,
+        let cache = match host {
+            Contiguous::Borrowed(host) => {
+                self.host.copy_from_slice(host);
+                None
+            }
+            Contiguous::Owned(host) => Some(replace(&mut self.host, host)),
         };
         stream.memcpy_h2d(&mut self.dev, &self.host);
         self.event = stream.record();
-        replace(&mut self.dev, stream.malloc::<u8>(self.host.len()))
+        (
+            replace(&mut self.dev, stream.malloc::<u8>(self.host.len())),
+            cache,
+        )
     }
 }
 
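
Two changes speed up weight loading here. H2DLoader::load now hands back, along with the finished device buffer, the pinned host staging buffer it displaced when the input was owned; the caller parks that buffer in the Slab keyed by its size, and the distribute closure reuses it for the next block instead of calling ctx.malloc_host every time, while the loader's event synchronize/record pairing still overlaps staging of one block with the in-flight copy of the previous one. Each block's load is also timed and reported via trace!. Below is a plain-Rust sketch of the recycling loop only; upload is a hypothetical stand-in for H2DLoader::load, and Vec<u8> stands in for the HostMem/DevMem buffers:

use common::Slab;

// Hypothetical stand-in for H2DLoader::load: copies `host` to the "device"
// and hands the staging buffer back so the caller can recycle it.
fn upload(host: Vec<u8>) -> (Vec<u8>, Option<Vec<u8>>) {
    let dev = host.clone(); // pretend this is the async H2D copy
    (dev, Some(host))
}

fn main() {
    let mut slab: Slab<usize, Vec<u8>> = Slab::new();
    let block_sizes = [1usize << 20; 3]; // three same-sized "blocks"
    let mut device_blocks = Vec::new();
    for size in block_sizes {
        // Reuse a pooled staging buffer of this size, or allocate it once.
        let staging = slab.take(&size).unwrap_or_else(|| vec![0u8; size]);
        let (dev, returned) = upload(staging);
        if let Some(host) = returned {
            slab.put(host.len(), host); // back into the pool for the next block
        }
        device_blocks.push(dev);
    }
    assert_eq!(device_blocks.len(), 3);
}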

models/llama/nvidia-gpu/src/nccl_parallel.rs

Lines changed: 6 additions & 1 deletion
@@ -14,7 +14,7 @@ use regex::Regex;
 use std::{
     iter::zip,
     slice::{from_raw_parts, from_raw_parts_mut},
-    thread,
+    thread, u64,
 };
 use test_utils::{test_infer_paralle, Inference, Task, TokenizerAndPrompt, WorkerSeed};
 
@@ -88,6 +88,11 @@ fn test_infer() {
         let WorkerSeed { node, tasks } = seed;
         node.processor().apply(|ctx| {
             let stream = ctx.stream();
+            let (free, _) = ctx.mem_info();
+
+            ctx.dev().set_mempool_threshold(u64::MAX);
+            let _ = stream.malloc::<u8>((free.0 >> 30).saturating_sub(1) << 30);
+
             info!("worker[{id}] loading weights...");
             let weights = Weights::new(model, range, count, usize::MAX, ctx);
             let mut worker = Worker::new(id, &node, meta.clone(), weights, id == 0);
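
The parallel test now warms each worker's device memory pool before loading weights: set_mempool_threshold(u64::MAX) asks the pool to keep freed allocations cached instead of releasing them back to the driver (presumably the release-threshold knob of CUDA's stream-ordered allocator), and the discarded stream.malloc grows the pool to roughly all free memory up front so later allocations are served from the cache. The requested size is the free byte count rounded down to whole GiB minus one GiB of headroom; a worked example of that arithmetic, using a plain u64 in place of whatever ctx.mem_info actually returns:

// Illustrative arithmetic only; `free` here is a plain byte count.
let free: u64 = 81_000_000_000; // e.g. ~75.4 GiB reported free
let warmup = (free >> 30).saturating_sub(1) << 30; // whole GiB minus 1 GiB headroom
assert_eq!(warmup, 74 << 30); // 74 GiB requested to grow the pool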
