@@ -7,16 +7,15 @@ use operators::{
77 nccl:: CommunicatorGroup ,
88 nvidia_gpu:: NcclNode ,
99 random_sample:: { KVPair , SampleArgs } ,
10- Blob , TopoNode ,
10+ TopoNode ,
1111} ;
1212use regex:: Regex ;
1313use std:: {
1414 iter:: zip,
1515 slice:: { from_raw_parts, from_raw_parts_mut} ,
16- sync:: mpsc:: { Receiver , Sender } ,
17- thread, usize,
16+ thread,
1817} ;
19- use test_utils:: { Inference , TokenizerAndPrompt } ;
18+ use test_utils:: { test_infer_paralle , Inference , Task , TokenizerAndPrompt , WorkerSeed } ;
2019
2120type Worker < ' w > = LlamaWorker < Operators < NcclNode , AllReduce > , Weights < ' w > > ;
2221
@@ -49,21 +48,27 @@ fn test_infer() {
4948 let sample_args = SampleArgs :: new ( temperature, top_p, top_k) . expect ( "invalid sample args" ) ;
5049 println ! ( "{sample_args:?}" ) ;
5150
52- let devices = match devices {
53- Some ( devices) => Regex :: new ( r"\d+" )
54- . unwrap ( )
55- . find_iter ( & devices)
56- . map ( |c| c. as_str ( ) . parse ( ) . unwrap ( ) )
57- . collect :: < Vec < _ > > ( ) ,
58- None => vec ! [ 0 ] ,
59- } ;
60- println ! ( "distribution: {devices:?}" ) ;
61-
51+ let devices = devices
52+ . map ( |devices| {
53+ Regex :: new ( r"\d+" )
54+ . unwrap ( )
55+ . find_iter ( & devices)
56+ . map ( |c| c. as_str ( ) . parse ( ) . unwrap ( ) )
57+ . collect ( )
58+ } )
59+ . unwrap_or_else ( || vec ! [ 1 ] ) ;
6260 let lens = vec ! [ 1 ; devices. len( ) ] ;
6361 let count = devices. len ( ) ;
62+ println ! ( "distribution: {devices:?}" ) ;
6463
6564 let ( seeds, senders) = match cuda:: init ( ) {
66- Ok ( ( ) ) => WorkerSeed :: new ( & devices) ,
65+ Ok ( ( ) ) => WorkerSeed :: new (
66+ CommunicatorGroup :: new ( & devices)
67+ . into_vec ( )
68+ . into_iter ( )
69+ . map ( |comm| NcclNode :: new ( comm, Default :: default ( ) ) )
70+ . collect ( ) ,
71+ ) ,
6772 Err ( NoDevice ) => return ,
6873 } ;
6974 thread:: scope ( |s| {
@@ -77,7 +82,6 @@ fn test_infer() {
7782 meta. distribute ( range. clone ( ) , count) ;
7883
7984 let model = & model;
80-
8185 Some ( s. spawn ( move || {
8286 let WorkerSeed { node, tasks } = seed;
8387 node. processor ( ) . apply ( |ctx| {
@@ -163,68 +167,7 @@ fn test_infer() {
163167 } )
164168 . collect :: < Vec < _ > > ( ) ;
165169
166- let ( next, next_recv) = std:: sync:: mpsc:: channel ( ) ;
167- test_utils:: test_infer ( eos, tokenizer, & prompt, max_steps, |input, pos| {
168- let mut embd = model. meta . embd ( input. len ( ) ) . map ( Blob :: new) ;
169-
170- let d = embd. get ( ) . len ( ) / input. len ( ) ;
171- for ( i, & tok) in input. iter ( ) . enumerate ( ) {
172- embd. get_mut ( ) [ i * d..] [ ..d]
173- . copy_from_slice ( & model. token_embd [ tok as usize * d..] [ ..d] ) ;
174- }
175- let embd = embd. take ( ) ;
176-
177- for sender in & senders {
178- sender
179- . send ( Task {
180- nt : input. len ( ) ,
181- pos,
182- embd : embd. as_ptr ( ) ,
183- next : next. clone ( ) ,
184- } )
185- . unwrap ( ) ;
186- }
187- next_recv. recv ( ) . unwrap ( )
188- } ) ;
189-
190- drop ( senders)
170+ let senders = senders. into_boxed_slice ( ) ;
171+ test_infer_paralle ( & model, senders, eos, tokenizer, & prompt, max_steps)
191172 } )
192173}
193-
194- struct Task {
195- nt : usize ,
196- pos : usize ,
197- embd : * const u8 ,
198- next : Sender < u32 > ,
199- }
200-
201- unsafe impl Send for Task { }
202-
203- struct WorkerSeed {
204- tasks : Receiver < Task > ,
205- node : NcclNode ,
206- }
207-
208- impl WorkerSeed {
209- fn new ( devices : & [ i32 ] ) -> ( Vec < Self > , Vec < Sender < Task > > ) {
210- let nodes = CommunicatorGroup :: new ( devices)
211- . into_vec ( )
212- . into_iter ( )
213- . map ( |comm| NcclNode :: new ( comm, Default :: default ( ) ) )
214- . collect :: < Vec < _ > > ( ) ;
215- let n = nodes. len ( ) ;
216- let mut tasks = Vec :: with_capacity ( n) ;
217- let mut senders = Vec :: with_capacity ( n) ;
218- for _ in 0 ..n {
219- let ( sender, receiver) = std:: sync:: mpsc:: channel ( ) ;
220- tasks. push ( receiver) ;
221- senders. push ( sender) ;
222- }
223- (
224- zip ( nodes, tasks)
225- . map ( |( node, tasks) | Self { node, tasks } )
226- . collect ( ) ,
227- senders,
228- )
229- }
230- }
0 commit comments