Skip to content

Commit 9ae4f93

Browse files
committed
feat(llama-nv): 简化 NV 推理
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 9c66170 commit 9ae4f93

File tree

2 files changed

+60
-67
lines changed

2 files changed

+60
-67
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ itertools = "0.13"
2828
build-script-cfg = "0.0"
2929

3030
ndarray-layout = { git = "https://github.com/YdrMaster/ndarray-layout", rev = "f1fdd24" }
31-
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "8c2227a", default-features = false }
31+
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "923949f", default-features = false }
3232

3333
search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "6846d52" }
3434
search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "136c30b" }

models/llama/nvidia-gpu/src/test_infer.rs

Lines changed: 59 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use operators::{
88
nvidia_gpu::{Config, Gpu},
99
random_sample::{KVPair, SampleArgs},
1010
};
11-
use std::{slice::from_raw_parts_mut, thread, usize};
11+
use std::{slice::from_raw_parts_mut, time::Instant, usize};
1212
use test_utils::{load_roll_cache_size, Inference, TokenizerAndPrompt};
1313

1414
type Worker<'w> = LlamaWorker<Operators, Weights<'w>>;
@@ -60,73 +60,66 @@ fn test_infer() {
6060
..
6161
} = meta;
6262

63-
thread::scope(|s| {
64-
let sample = s.spawn(move || {
65-
let mut sample = RandomSample::new(gpu);
66-
sample.scheme(dt_embd, nvoc).unwrap();
67-
sample
68-
});
69-
gpu.apply(|ctx| {
70-
let stream = ctx.stream();
71-
72-
let token_embd = stream.from_host(model.token_embd);
73-
let weights = Weights::new(&model, .., 1, roll_cache_size, ctx);
74-
let mut worker = Worker::new(&gpu, meta.clone(), weights, true);
75-
let mut cache = meta.kv_cache(nctx).map(|size| stream.malloc::<u8>(size));
76-
let sin_cos =
77-
<Operators as llama::Operators>::build_sin_cos(dt_embd, nctx, dh, &stream);
78-
let indices = RandomSample::build_indices(nvoc, &stream);
79-
80-
let sample = sample.join().unwrap();
81-
test_utils::test_infer(eos, tokenizer, &prompt, max_steps, |input, pos| {
82-
let mut embd = meta.embd(input.len()).map(|len| stream.malloc::<u8>(len));
83-
let mut logits = meta.logits(1).map(|len| stream.malloc::<u8>(len));
84-
85-
let d = embd.get().len() / input.len();
86-
for (i, &tok) in input.iter().enumerate() {
87-
stream.memcpy_d2d(
88-
&mut embd.get_mut()[i * d..][..d],
89-
&token_embd[tok as usize * d..][..d],
90-
)
91-
}
92-
93-
worker
94-
.launch(
95-
LlamaArgs {
96-
embd: embd.map_slice_mut(),
97-
logits: logits.map_slice_mut(),
98-
sin_cos: sin_cos.map_slice(),
99-
requests: vec![LlamaRequest {
100-
cache: cache.map_slice_mut(),
101-
seq_len: input.len(),
102-
out_len: 1,
103-
pos,
104-
}],
105-
num_tokens: input.len(),
106-
max_seq_len: input.len(),
107-
max_att_len: pos + input.len(),
108-
},
109-
&mut [],
110-
&stream,
111-
)
112-
.unwrap();
113-
114-
let mut pairs = Tensor::kv_pair_vec(1, |size| stream.malloc::<u8>(size));
115-
116-
sample
117-
.launch(&mut pairs, &logits, &indices, sample_args, &mut [], &stream)
118-
.unwrap();
119-
120-
let mut pair = KVPair::new(0, f16::ZERO);
121-
memcpy_d2h(
122-
unsafe {
123-
from_raw_parts_mut(&mut pair as *mut _ as *mut u8, size_of_val(&pair))
63+
gpu.apply(|ctx| {
64+
let stream = ctx.stream();
65+
66+
let time = Instant::now();
67+
let token_embd = stream.from_host(model.token_embd);
68+
let weights = Weights::new(&model, .., 1, roll_cache_size, ctx);
69+
println!("load weights: {:?}", time.elapsed());
70+
71+
let mut worker = Worker::new(&gpu, meta.clone(), weights, true);
72+
let mut cache = meta.kv_cache(nctx).map(|size| stream.malloc::<u8>(size));
73+
let sin_cos = <Operators as llama::Operators>::build_sin_cos(dt_embd, nctx, dh, &stream);
74+
let indices = RandomSample::build_indices(nvoc, &stream);
75+
let sample = RandomSample::new(gpu);
76+
77+
test_utils::test_infer(eos, tokenizer, &prompt, max_steps, |input, pos| {
78+
let mut embd = meta.embd(input.len()).map(|len| stream.malloc::<u8>(len));
79+
let mut logits = meta.logits(1).map(|len| stream.malloc::<u8>(len));
80+
81+
let d = embd.get().len() / input.len();
82+
for (i, &tok) in input.iter().enumerate() {
83+
stream.memcpy_d2d(
84+
&mut embd.get_mut()[i * d..][..d],
85+
&token_embd[tok as usize * d..][..d],
86+
)
87+
}
88+
89+
worker
90+
.launch(
91+
LlamaArgs {
92+
embd: embd.map_slice_mut(),
93+
logits: logits.map_slice_mut(),
94+
sin_cos: sin_cos.map_slice(),
95+
requests: vec![LlamaRequest {
96+
cache: cache.map_slice_mut(),
97+
seq_len: input.len(),
98+
out_len: 1,
99+
pos,
100+
}],
101+
num_tokens: input.len(),
102+
max_seq_len: input.len(),
103+
max_att_len: pos + input.len(),
124104
},
125-
pairs.get(),
126-
);
105+
&mut [],
106+
&stream,
107+
)
108+
.unwrap();
109+
110+
let mut pairs = Tensor::kv_pair_vec(1, |size| stream.malloc::<u8>(size));
111+
112+
sample
113+
.launch(&mut pairs, &logits, &indices, sample_args, &mut [], &stream)
114+
.unwrap();
115+
116+
let mut pair = KVPair::new(0, f16::ZERO);
117+
memcpy_d2h(
118+
unsafe { from_raw_parts_mut(&mut pair as *mut _ as *mut u8, size_of_val(&pair)) },
119+
pairs.get(),
120+
);
127121

128-
pair.idx() as _
129-
});
122+
pair.idx() as _
130123
});
131124
});
132125
}

0 commit comments

Comments
 (0)