Skip to content

Commit 301cee2

Browse files
committed
perf(nv): 加速大模型加载 (speed up loading of large models)
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent f16b51a commit 301cee2

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

models/llama/nvidia-gpu/src/lib.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use operators::{
77
nvidia_gpu::Gpu,
88
random_sample::nvidia_gpu::Operator as RandomSampleGpu,
99
rearrange::nvidia_gpu::Operator as Rearrange,
10-
ByteOf, QueueOf, TopoNode,
10+
Blob, ByteOf, QueueOf, TopoNode,
1111
};
1212
use std::{
1313
cell::{RefCell, RefMut},
@@ -207,9 +207,7 @@ impl<'blk> Weights<'blk> {
207207
.as_ref()
208208
.map(|_| Vec::with_capacity(model.meta.nblk));
209209
for blk in &model.blocks {
210-
let blk = blk.distribute(&model.meta, range.clone(), count, |len| {
211-
ctx.malloc_host::<u8>(len)
212-
});
210+
let blk = blk.distribute(&model.meta, range.clone(), count, Blob::new);
213211
let loader = loader
214212
.get_or_insert_with(|| blk.as_ref().map(|s| H2DLoader::new(s.len(), &stream)));
215213

@@ -240,20 +238,20 @@ impl<'blk> Weights<'blk> {
240238

241239
struct H2DLoader<'ctx> {
242240
event: Event<'ctx>,
243-
host: HostMem<'ctx>,
241+
host: Blob,
244242
dev: DevMem<'ctx>,
245243
}
246244

247245
impl<'ctx> H2DLoader<'ctx> {
248246
fn new(size: usize, stream: &Stream<'ctx>) -> Self {
249247
Self {
250248
event: stream.record(),
251-
host: stream.ctx().malloc_host::<u8>(size),
249+
host: Blob::new(size),
252250
dev: stream.malloc::<u8>(size),
253251
}
254252
}
255253

256-
fn load(&mut self, host: Contiguous<HostMem<'ctx>>, stream: &Stream<'ctx>) -> DevMem<'ctx> {
254+
fn load(&mut self, host: Contiguous<Blob>, stream: &Stream<'ctx>) -> DevMem<'ctx> {
257255
self.event.synchronize();
258256
match host {
259257
Contiguous::Borrowed(host) => self.host.copy_from_slice(host),

models/llama/nvidia-gpu/src/nccl_parallel.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use regex::Regex;
1414
use std::{
1515
iter::zip,
1616
slice::{from_raw_parts, from_raw_parts_mut},
17-
thread,
17+
thread, u64,
1818
};
1919
use test_utils::{test_infer_paralle, Inference, Task, TokenizerAndPrompt, WorkerSeed};
2020

@@ -88,6 +88,14 @@ fn test_infer() {
8888
let WorkerSeed { node, tasks } = seed;
8989
node.processor().apply(|ctx| {
9090
let stream = ctx.stream();
91+
92+
let mut free = 0;
93+
let mut total = 0;
94+
cuda::driver!(cuMemGetInfo_v2(&mut free, &mut total));
95+
96+
ctx.dev().set_mempool_threshold(u64::MAX);
97+
let _ = stream.malloc::<u8>((free >> 30).saturating_sub(1) << 30);
98+
9199
info!("worker[{id}] loading weights...");
92100
let weights = Weights::new(model, range, count, usize::MAX, ctx);
93101
let mut worker = Worker::new(id, &node, meta.clone(), weights, id == 0);

0 commit comments

Comments
 (0)