@@ -3,7 +3,7 @@ use gguf::GGufModel;
33use llama:: { ext:: ggml_quants:: f16, LlamaRequest , LlamaStorage , LlamaWorker , Tensor } ;
44use operators:: {
55 all_reduce:: common_cpu:: Operator as AllReduce ,
6- common_cpu:: { Cpu , InprocNode , ThisThread } ,
6+ common_cpu:: { InprocNode , ThisThread } ,
77 random_sample:: { KVPair , SampleArgs } ,
88 Blob ,
99} ;
@@ -12,10 +12,7 @@ use std::{
1212 iter:: zip,
1313 ptr:: copy_nonoverlapping,
1414 slice:: from_raw_parts_mut,
15- sync:: {
16- mpsc:: { Receiver , Sender } ,
17- Arc , Barrier ,
18- } ,
15+ sync:: mpsc:: { Receiver , Sender } ,
1916 thread,
2017} ;
2118use test_utils:: { Inference , TokenizerAndPrompt } ;
@@ -52,13 +49,11 @@ fn test_infer() {
5249 println ! ( "{sample_args:?}" ) ;
5350
5451 let lens = match devices {
55- Some ( devices) => {
56- let regex = Regex :: new ( r"\d+" ) . unwrap ( ) ;
57- regex
58- . find_iter ( & devices)
59- . map ( |c| c. as_str ( ) . parse :: < usize > ( ) . unwrap ( ) )
60- . collect :: < Vec < _ > > ( )
61- }
52+ Some ( devices) => Regex :: new ( r"\d+" )
53+ . unwrap ( )
54+ . find_iter ( & devices)
55+ . map ( |c| c. as_str ( ) . parse :: < usize > ( ) . unwrap ( ) )
56+ . collect :: < Vec < _ > > ( ) ,
6257 None => vec ! [ 1 ] ,
6358 } ;
6459 println ! ( "distribution: {lens:?}" ) ;
@@ -87,25 +82,27 @@ fn test_infer() {
8782 meta. dh ,
8883 & ThisThread ,
8984 ) ;
85+
86+ let sample = RandomSample :: new ( & node) ;
87+ let indices = RandomSample :: build_indices ( model. meta . nvoc , & ThisThread ) ;
88+ let mut pair = KVPair :: new ( 0 , f16:: ZERO ) ;
89+ let mut pairs = Tensor :: kv_pair_vec ( 1 , |_| unsafe {
90+ from_raw_parts_mut ( & mut pair as * mut _ as * mut u8 , size_of_val ( & pair) )
91+ } ) ;
92+
9093 for task in tasks {
9194 let Task {
9295 nt,
9396 pos,
9497 embd,
95- logits,
96- barrier,
98+ next,
9799 } = task;
98100 let mut embd = meta. embd ( nt) . map ( |size| {
99101 let mut blob = Blob :: new ( size) ;
100102 unsafe { copy_nonoverlapping ( embd, blob. as_mut_ptr ( ) , size) } ;
101103 blob
102104 } ) ;
103- let mut logits = if i == 0 {
104- meta. logits ( 1 )
105- . map ( |size| unsafe { from_raw_parts_mut ( logits, size) } )
106- } else {
107- meta. logits ( 0 ) . map ( |_| & mut [ ] [ ..] )
108- } ;
105+ let mut logits = meta. logits ( if i == 0 { 1 } else { 0 } ) . map ( Blob :: new) ;
109106 worker
110107 . launch (
111108 llama:: LlamaArgs {
@@ -126,17 +123,27 @@ fn test_infer() {
126123 & ThisThread ,
127124 )
128125 . unwrap ( ) ;
129- barrier. wait ( ) ;
126+ if i == 0 {
127+ sample
128+ . launch (
129+ & mut pairs,
130+ & logits,
131+ & indices,
132+ sample_args,
133+ & mut [ ] ,
134+ & ThisThread ,
135+ )
136+ . unwrap ( ) ;
137+ next. send ( pair. idx ( ) as _ ) . unwrap ( )
138+ }
130139 }
131140 } ) )
132141 } )
133142 . collect :: < Vec < _ > > ( ) ;
134143
135- let sample = RandomSample :: new ( & Cpu ) ;
136- let indices = RandomSample :: build_indices ( model. meta . nvoc , & ThisThread ) ;
144+ let ( next, next_recv) = std:: sync:: mpsc:: channel ( ) ;
137145 test_utils:: test_infer ( eos, tokenizer, & prompt, max_steps, |input, pos| {
138146 let mut embd = model. meta . embd ( input. len ( ) ) . map ( Blob :: new) ;
139- let mut logits = model. meta . logits ( 1 ) . map ( Blob :: new) ;
140147
141148 let d = embd. get ( ) . len ( ) / input. len ( ) ;
142149 for ( i, & tok) in input. iter ( ) . enumerate ( ) {
@@ -145,49 +152,28 @@ fn test_infer() {
145152 }
146153 let embd = embd. take ( ) ;
147154
148- let barrier = Arc :: new ( Barrier :: new ( senders. len ( ) + 1 ) ) ;
149155 for sender in & senders {
150156 sender
151157 . send ( Task {
152158 nt : input. len ( ) ,
153159 pos,
154160 embd : embd. as_ptr ( ) ,
155- logits : logits. get_mut ( ) . as_mut_ptr ( ) ,
156- barrier : barrier. clone ( ) ,
161+ next : next. clone ( ) ,
157162 } )
158163 . unwrap ( ) ;
159164 }
160- barrier. wait ( ) ;
161-
162- let mut pair = KVPair :: new ( 0 , f16:: ZERO ) ;
163- let mut pairs = Tensor :: kv_pair_vec ( 1 , |_| unsafe {
164- from_raw_parts_mut ( & mut pair as * mut _ as _ , size_of_val ( & pair) )
165- } ) ;
166-
167- sample
168- . launch (
169- & mut pairs,
170- & logits,
171- & indices,
172- sample_args,
173- & mut [ ] ,
174- & ThisThread ,
175- )
176- . unwrap ( ) ;
177-
178- pair. idx ( ) as _
165+ next_recv. recv ( ) . unwrap ( )
179166 } ) ;
180167
181- drop ( senders) ;
168+ drop ( senders)
182169 } )
183170}
184171
// Unit of work sent over a channel to each worker thread for one inference step.
// This hunk removes the shared `logits` output pointer and the `Barrier`; instead each
// task carries a `Sender` through which the worker with `i == 0` returns the sampled token.
185172struct Task {
// Number of input tokens for this step (the sender fills it with `input.len()`).
186173 nt : usize ,
// Position value forwarded unchanged from the driver closure's `pos` argument.
187174 pos : usize ,
// Raw pointer to the sender's embedding bytes; the worker copies them into its own `Blob`
// with `copy_nonoverlapping` before launching.
188175 embd : * const u8 ,
189- logits : * mut u8 ,
190- barrier : Arc < Barrier > ,
// Channel end used by the worker with `i == 0` to send back `pair.idx()` after sampling;
// the driver blocks on the matching receiver.
176+ next : Sender < u32 > ,
191177}
192178
// `Task` holds a raw pointer (`embd`), which is not `Send` automatically, so `Send` is
// asserted by hand to let tasks cross the worker channels.
// NOTE(review): soundness relies on the sender's embedding buffer staying alive until every
// worker has copied from it. With the barrier removed, the driver only waits for the
// `i == 0` worker's reply — confirm the other workers have finished reading by then.
193179unsafe impl Send for Task { }
0 commit comments