Skip to content

Commit 37c0d87

Browse files
committed
feat(nv): 初步 nv 分布式推理
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent b59edbf commit 37c0d87

File tree

8 files changed

+281
-54
lines changed

8 files changed

+281
-54
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ env_logger = "0.11"
3030
build-script-cfg = "0.0"
3131

3232
ndarray-layout = { git = "https://github.com/YdrMaster/ndarray-layout", rev = "f1fdd24" }
33-
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "923949f", default-features = false }
33+
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "b9e6fdd", default-features = false }
3434

3535
search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "6846d52" }
3636
search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "136c30b" }

models/llama/common-cpu/src/infer.rs

Lines changed: 35 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use gguf::GGufModel;
33
use llama::{ext::ggml_quants::f16, LlamaRequest, LlamaStorage, LlamaWorker, Tensor};
44
use operators::{
55
all_reduce::common_cpu::Operator as AllReduce,
6-
common_cpu::{Cpu, InprocNode, ThisThread},
6+
common_cpu::{InprocNode, ThisThread},
77
random_sample::{KVPair, SampleArgs},
88
Blob,
99
};
@@ -12,10 +12,7 @@ use std::{
1212
iter::zip,
1313
ptr::copy_nonoverlapping,
1414
slice::from_raw_parts_mut,
15-
sync::{
16-
mpsc::{Receiver, Sender},
17-
Arc, Barrier,
18-
},
15+
sync::mpsc::{Receiver, Sender},
1916
thread,
2017
};
2118
use test_utils::{Inference, TokenizerAndPrompt};
@@ -52,13 +49,11 @@ fn test_infer() {
5249
println!("{sample_args:?}");
5350

5451
let lens = match devices {
55-
Some(devices) => {
56-
let regex = Regex::new(r"\d+").unwrap();
57-
regex
58-
.find_iter(&devices)
59-
.map(|c| c.as_str().parse::<usize>().unwrap())
60-
.collect::<Vec<_>>()
61-
}
52+
Some(devices) => Regex::new(r"\d+")
53+
.unwrap()
54+
.find_iter(&devices)
55+
.map(|c| c.as_str().parse::<usize>().unwrap())
56+
.collect::<Vec<_>>(),
6257
None => vec![1],
6358
};
6459
println!("distribution: {lens:?}");
@@ -87,25 +82,27 @@ fn test_infer() {
8782
meta.dh,
8883
&ThisThread,
8984
);
85+
86+
let sample = RandomSample::new(&node);
87+
let indices = RandomSample::build_indices(model.meta.nvoc, &ThisThread);
88+
let mut pair = KVPair::new(0, f16::ZERO);
89+
let mut pairs = Tensor::kv_pair_vec(1, |_| unsafe {
90+
from_raw_parts_mut(&mut pair as *mut _ as *mut u8, size_of_val(&pair))
91+
});
92+
9093
for task in tasks {
9194
let Task {
9295
nt,
9396
pos,
9497
embd,
95-
logits,
96-
barrier,
98+
next,
9799
} = task;
98100
let mut embd = meta.embd(nt).map(|size| {
99101
let mut blob = Blob::new(size);
100102
unsafe { copy_nonoverlapping(embd, blob.as_mut_ptr(), size) };
101103
blob
102104
});
103-
let mut logits = if i == 0 {
104-
meta.logits(1)
105-
.map(|size| unsafe { from_raw_parts_mut(logits, size) })
106-
} else {
107-
meta.logits(0).map(|_| &mut [][..])
108-
};
105+
let mut logits = meta.logits(if i == 0 { 1 } else { 0 }).map(Blob::new);
109106
worker
110107
.launch(
111108
llama::LlamaArgs {
@@ -126,17 +123,27 @@ fn test_infer() {
126123
&ThisThread,
127124
)
128125
.unwrap();
129-
barrier.wait();
126+
if i == 0 {
127+
sample
128+
.launch(
129+
&mut pairs,
130+
&logits,
131+
&indices,
132+
sample_args,
133+
&mut [],
134+
&ThisThread,
135+
)
136+
.unwrap();
137+
next.send(pair.idx() as _).unwrap()
138+
}
130139
}
131140
}))
132141
})
133142
.collect::<Vec<_>>();
134143

135-
let sample = RandomSample::new(&Cpu);
136-
let indices = RandomSample::build_indices(model.meta.nvoc, &ThisThread);
144+
let (next, next_recv) = std::sync::mpsc::channel();
137145
test_utils::test_infer(eos, tokenizer, &prompt, max_steps, |input, pos| {
138146
let mut embd = model.meta.embd(input.len()).map(Blob::new);
139-
let mut logits = model.meta.logits(1).map(Blob::new);
140147

141148
let d = embd.get().len() / input.len();
142149
for (i, &tok) in input.iter().enumerate() {
@@ -145,49 +152,28 @@ fn test_infer() {
145152
}
146153
let embd = embd.take();
147154

148-
let barrier = Arc::new(Barrier::new(senders.len() + 1));
149155
for sender in &senders {
150156
sender
151157
.send(Task {
152158
nt: input.len(),
153159
pos,
154160
embd: embd.as_ptr(),
155-
logits: logits.get_mut().as_mut_ptr(),
156-
barrier: barrier.clone(),
161+
next: next.clone(),
157162
})
158163
.unwrap();
159164
}
160-
barrier.wait();
161-
162-
let mut pair = KVPair::new(0, f16::ZERO);
163-
let mut pairs = Tensor::kv_pair_vec(1, |_| unsafe {
164-
from_raw_parts_mut(&mut pair as *mut _ as _, size_of_val(&pair))
165-
});
166-
167-
sample
168-
.launch(
169-
&mut pairs,
170-
&logits,
171-
&indices,
172-
sample_args,
173-
&mut [],
174-
&ThisThread,
175-
)
176-
.unwrap();
177-
178-
pair.idx() as _
165+
next_recv.recv().unwrap()
179166
});
180167

181-
drop(senders);
168+
drop(senders)
182169
})
183170
}
184171

185172
struct Task {
186173
nt: usize,
187174
pos: usize,
188175
embd: *const u8,
189-
logits: *mut u8,
190-
barrier: Arc<Barrier>,
176+
next: Sender<u32>,
191177
}
192178

193179
unsafe impl Send for Task {}

models/llama/common/src/compute.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,9 @@ where
264264
self.all_reduce(&mut x, workspace, queue_alloc)?;
265265
}
266266
}
267+
if logits.shape()[0] == 0 {
268+
return Ok(());
269+
}
267270

268271
// 集中要采样的 token
269272
// NOTICE: 输入之前将请求按 seq len 升序排列可降低移动开销

models/llama/nvidia-gpu/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ search-cuda-tools.workspace = true
1717
[dev-dependencies]
1818
test-utils.workspace = true
1919
gguf.workspace = true
20+
regex.workspace = true

models/llama/nvidia-gpu/build.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
fn main() {
22
use build_script_cfg::Cfg;
3-
use search_cuda_tools::find_cuda_root;
3+
use search_cuda_tools::{find_cuda_root, find_nccl_root};
44

55
let cfg = Cfg::new("hw_detected");
6+
let nccl = Cfg::new("nccl_detected");
67
if find_cuda_root().is_some() {
78
cfg.define();
9+
if find_nccl_root().is_some() {
10+
nccl.define();
11+
}
812
}
913
}

models/llama/nvidia-gpu/src/infer.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@ fn test_infer() {
2828
else {
2929
return;
3030
};
31-
32-
let roll_cache_size = load_roll_cache_size();
33-
println!("roll_cache_size: {roll_cache_size}");
3431
let gguf = GGufModel::read(model.iter().map(|s| &**s));
3532

3633
let TokenizerAndPrompt {
@@ -45,6 +42,9 @@ fn test_infer() {
4542
let sample_args = SampleArgs::new(temperature, top_p, top_k).expect("invalid sample args");
4643
println!("{sample_args:?}");
4744

45+
let roll_cache_size = load_roll_cache_size();
46+
println!("roll_cache_size: {roll_cache_size}");
47+
4848
let gpu = match cuda::init() {
4949
Ok(()) => Device::new(0),
5050
Err(NoDevice) => return,

models/llama/nvidia-gpu/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,6 @@ impl<'ctx> WeightLoader for Weights<'ctx> {
317317

318318
#[cfg(test)]
319319
mod infer;
320+
321+
#[cfg(all(test, nccl_detected))]
322+
mod nccl_parallel;

0 commit comments

Comments
 (0)