@@ -8,7 +8,7 @@ use operators::{
88 nvidia_gpu:: { Config , Gpu } ,
99 random_sample:: { KVPair , SampleArgs } ,
1010} ;
11- use std:: { slice:: from_raw_parts_mut, thread , usize} ;
11+ use std:: { slice:: from_raw_parts_mut, time :: Instant , usize} ;
1212use test_utils:: { load_roll_cache_size, Inference , TokenizerAndPrompt } ;
1313
1414type Worker < ' w > = LlamaWorker < Operators , Weights < ' w > > ;
@@ -60,73 +60,66 @@ fn test_infer() {
6060 ..
6161 } = meta;
6262
63- thread:: scope ( |s| {
64- let sample = s. spawn ( move || {
65- let mut sample = RandomSample :: new ( gpu) ;
66- sample. scheme ( dt_embd, nvoc) . unwrap ( ) ;
67- sample
68- } ) ;
69- gpu. apply ( |ctx| {
70- let stream = ctx. stream ( ) ;
71-
72- let token_embd = stream. from_host ( model. token_embd ) ;
73- let weights = Weights :: new ( & model, .., 1 , roll_cache_size, ctx) ;
74- let mut worker = Worker :: new ( & gpu, meta. clone ( ) , weights, true ) ;
75- let mut cache = meta. kv_cache ( nctx) . map ( |size| stream. malloc :: < u8 > ( size) ) ;
76- let sin_cos =
77- <Operators as llama:: Operators >:: build_sin_cos ( dt_embd, nctx, dh, & stream) ;
78- let indices = RandomSample :: build_indices ( nvoc, & stream) ;
79-
80- let sample = sample. join ( ) . unwrap ( ) ;
81- test_utils:: test_infer ( eos, tokenizer, & prompt, max_steps, |input, pos| {
82- let mut embd = meta. embd ( input. len ( ) ) . map ( |len| stream. malloc :: < u8 > ( len) ) ;
83- let mut logits = meta. logits ( 1 ) . map ( |len| stream. malloc :: < u8 > ( len) ) ;
84-
85- let d = embd. get ( ) . len ( ) / input. len ( ) ;
86- for ( i, & tok) in input. iter ( ) . enumerate ( ) {
87- stream. memcpy_d2d (
88- & mut embd. get_mut ( ) [ i * d..] [ ..d] ,
89- & token_embd[ tok as usize * d..] [ ..d] ,
90- )
91- }
92-
93- worker
94- . launch (
95- LlamaArgs {
96- embd : embd. map_slice_mut ( ) ,
97- logits : logits. map_slice_mut ( ) ,
98- sin_cos : sin_cos. map_slice ( ) ,
99- requests : vec ! [ LlamaRequest {
100- cache: cache. map_slice_mut( ) ,
101- seq_len: input. len( ) ,
102- out_len: 1 ,
103- pos,
104- } ] ,
105- num_tokens : input. len ( ) ,
106- max_seq_len : input. len ( ) ,
107- max_att_len : pos + input. len ( ) ,
108- } ,
109- & mut [ ] ,
110- & stream,
111- )
112- . unwrap ( ) ;
113-
114- let mut pairs = Tensor :: kv_pair_vec ( 1 , |size| stream. malloc :: < u8 > ( size) ) ;
115-
116- sample
117- . launch ( & mut pairs, & logits, & indices, sample_args, & mut [ ] , & stream)
118- . unwrap ( ) ;
119-
120- let mut pair = KVPair :: new ( 0 , f16:: ZERO ) ;
121- memcpy_d2h (
122- unsafe {
123- from_raw_parts_mut ( & mut pair as * mut _ as * mut u8 , size_of_val ( & pair) )
63+ gpu. apply ( |ctx| {
64+ let stream = ctx. stream ( ) ;
65+
66+ let time = Instant :: now ( ) ;
67+ let token_embd = stream. from_host ( model. token_embd ) ;
68+ let weights = Weights :: new ( & model, .., 1 , roll_cache_size, ctx) ;
69+ println ! ( "load weights: {:?}" , time. elapsed( ) ) ;
70+
71+ let mut worker = Worker :: new ( & gpu, meta. clone ( ) , weights, true ) ;
72+ let mut cache = meta. kv_cache ( nctx) . map ( |size| stream. malloc :: < u8 > ( size) ) ;
73+ let sin_cos = <Operators as llama:: Operators >:: build_sin_cos ( dt_embd, nctx, dh, & stream) ;
74+ let indices = RandomSample :: build_indices ( nvoc, & stream) ;
75+ let sample = RandomSample :: new ( gpu) ;
76+
77+ test_utils:: test_infer ( eos, tokenizer, & prompt, max_steps, |input, pos| {
78+ let mut embd = meta. embd ( input. len ( ) ) . map ( |len| stream. malloc :: < u8 > ( len) ) ;
79+ let mut logits = meta. logits ( 1 ) . map ( |len| stream. malloc :: < u8 > ( len) ) ;
80+
81+ let d = embd. get ( ) . len ( ) / input. len ( ) ;
82+ for ( i, & tok) in input. iter ( ) . enumerate ( ) {
83+ stream. memcpy_d2d (
84+ & mut embd. get_mut ( ) [ i * d..] [ ..d] ,
85+ & token_embd[ tok as usize * d..] [ ..d] ,
86+ )
87+ }
88+
89+ worker
90+ . launch (
91+ LlamaArgs {
92+ embd : embd. map_slice_mut ( ) ,
93+ logits : logits. map_slice_mut ( ) ,
94+ sin_cos : sin_cos. map_slice ( ) ,
95+ requests : vec ! [ LlamaRequest {
96+ cache: cache. map_slice_mut( ) ,
97+ seq_len: input. len( ) ,
98+ out_len: 1 ,
99+ pos,
100+ } ] ,
101+ num_tokens : input. len ( ) ,
102+ max_seq_len : input. len ( ) ,
103+ max_att_len : pos + input. len ( ) ,
124104 } ,
125- pairs. get ( ) ,
126- ) ;
105+ & mut [ ] ,
106+ & stream,
107+ )
108+ . unwrap ( ) ;
109+
110+ let mut pairs = Tensor :: kv_pair_vec ( 1 , |size| stream. malloc :: < u8 > ( size) ) ;
111+
112+ sample
113+ . launch ( & mut pairs, & logits, & indices, sample_args, & mut [ ] , & stream)
114+ . unwrap ( ) ;
115+
116+ let mut pair = KVPair :: new ( 0 , f16:: ZERO ) ;
117+ memcpy_d2h (
118+ unsafe { from_raw_parts_mut ( & mut pair as * mut _ as * mut u8 , size_of_val ( & pair) ) } ,
119+ pairs. get ( ) ,
120+ ) ;
127121
128- pair. idx ( ) as _
129- } ) ;
122+ pair. idx ( ) as _
130123 } ) ;
131124 } ) ;
132125}
0 commit comments