@@ -7,9 +7,9 @@ use memmap2::Mmap;
77use operators:: {
88 common_cpu:: { Cpu , ThisThread } ,
99 random_sample:: { common_cpu:: Operator as CpuOp , KVPair , SampleArgs } ,
10- QueueOf ,
10+ ByteOf , QueueOf ,
1111} ;
12- use std:: slice:: from_raw_parts_mut;
12+ use std:: { ops :: Deref , slice:: from_raw_parts_mut} ;
1313use tensor:: { ArrayLayout , BigEndian , Tensor } ;
1414
1515pub struct Llama {
@@ -62,7 +62,7 @@ impl Llama {
6262 let mut embd_buf = vec ! [ 0u8 ; embd. shape( ) . iter( ) . product:: <usize >( ) * ele] ;
6363 let mut logits_buf = vec ! [ 0u8 ; logits. shape( ) . iter( ) . product:: <usize >( ) * ele] ;
6464
65- let d = embd. shape ( ) [ 1 ] ;
65+ let d = embd. shape ( ) [ 1 ] * ele ;
6666 for ( i, & tok) in input. iter ( ) . enumerate ( ) {
6767 embd_buf[ i * d..] [ ..d] . copy_from_slice ( & self . token_embed [ tok as usize * d..] [ ..d] ) ;
6868 }
@@ -132,6 +132,13 @@ impl llama::Operators for Operators {
132132 type AttnKVCached = op ! ( attention_kv_cached) ;
133133 type Mlp = op ! ( mlp) ;
134134 type Rearrange = op ! ( rearrange) ;
135+
136+ fn debug < T > ( tensor : & Tensor < T > )
137+ where
138+ T : Deref < Target = [ ByteOf < Self :: Hardware > ] > ,
139+ {
140+ println ! ( "{tensor}" ) ;
141+ }
135142}
136143
137144struct Weights {
@@ -174,14 +181,19 @@ impl WeightLoader for Weights {
174181}
175182
176183#[ test]
177- fn test_load ( ) {
178- use gguf:: GGufModel ;
179- use std:: { io:: Write , slice:: from_raw_parts} ;
184+ fn test_infer ( ) {
185+ use gguf:: { GGufMetaMapExt , GGufModel } ;
186+ use std:: {
187+ io:: Write ,
188+ slice:: from_raw_parts,
189+ time:: { Duration , Instant } ,
190+ } ;
180191
181192 let Some ( shards) = test_utils:: map_gguf_files ( ) else {
182193 return ;
183194 } ;
184195 let gguf = GGufModel :: read ( shards. iter ( ) . map ( |s| & * * s) ) ;
196+ let eos = gguf. tokenizer_ggml_eos_token_id ( ) . unwrap ( ) ;
185197 let tokenizer = gguf. tokenizer ( ) ;
186198 let llama =
187199 LlamaStorage :: from_gguf ( & gguf) . map ( & mut |s| unsafe { from_raw_parts ( s. as_ptr ( ) , s. len ( ) ) } ) ;
@@ -194,14 +206,50 @@ fn test_load() {
194206 let mut cache_buf = vec ! [ 0u8 ; cache. shape( ) . iter( ) . product:: <usize >( ) * size_of:: <f16>( ) ] ;
195207
196208 let mut prompt = "Once upon a time," . to_string ( ) ;
209+
210+ print ! ( "{prompt}" ) ;
211+ std:: io:: stdout ( ) . flush ( ) . unwrap ( ) ;
212+
197213 let mut tokens = tokenizer. encode ( & prompt) ;
198- while !tokens. contains ( & 2 ) {
199- let next = llama. infer ( & tokens, & mut cache_buf, 0 ) ;
200- tokens = vec ! [ next] ;
214+ let num_prompt_tokens = tokens. len ( ) ;
215+
216+ let mut prefill = Duration :: ZERO ;
217+ let mut decode = Duration :: ZERO ;
218+
219+ let mut pos = 0 ;
220+ loop {
221+ let time = Instant :: now ( ) ;
222+ let next = llama. infer ( & tokens, & mut cache_buf, pos) ;
223+ let time = time. elapsed ( ) ;
224+
225+ if prefill. is_zero ( ) {
226+ prefill = time;
227+ } else {
228+ decode += time;
229+ }
230+
231+ pos += tokens. len ( ) ;
232+ if next == eos {
233+ break ;
234+ }
201235
202236 let piece = tokenizer. decode ( next) ;
203237 print ! ( "{piece}" ) ;
204238 std:: io:: stdout ( ) . flush ( ) . unwrap ( ) ;
205239 prompt. push_str ( & piece) ;
240+ tokens = vec ! [ next] ;
241+ }
242+
243+ println ! ( ) ;
244+ println ! ( ) ;
245+ print_time ( "total" , prefill + decode, pos) ;
246+ print_time ( "prefill" , prefill, num_prompt_tokens) ;
247+ print_time ( "decode" , decode, pos - num_prompt_tokens) ;
248+
/// Prints a one-line timing summary to standard output.
///
/// * `name` - label placed at the start of the line.
/// * `time` - total elapsed time being reported.
/// * `n`    - number of tokens processed during `time`.
///
/// When `n` is zero no per-token average exists, so the average is reported
/// as `n/a` instead of being computed.
fn print_time(name: &str, time: Duration, n: usize) {
    // `Duration::div_f64` panics when the quotient is not finite, which is
    // exactly what a division by zero tokens produces (NaN or infinity).
    if n == 0 {
        println!("{name} : {time:?} for {n} tokens, avg: n/a");
        return;
    }

    println!(
        "{name} : {time:?} for {n} tokens, avg: {:?} per token",
        time.div_f64(n as f64)
    );
}
207255}
0 commit comments