11use crate :: { Operators , RandomSample , Weights } ;
22use gguf:: GGufModel ;
3- use llama:: {
4- ext:: ggml_quants:: f16, LlamaArgs , LlamaMeta , LlamaRequest , LlamaStorage , LlamaWorker , Tensor ,
5- } ;
3+ use llama:: { ext:: ggml_quants:: f16, LlamaRequest , LlamaStorage , LlamaWorker , Tensor } ;
64use operators:: {
75 infini_rt:: { self , Device , DeviceType :: DEVICE_CPU } ,
86 random_sample:: { KVPair , SampleArgs } ,
7+ TopoNode ,
8+ } ;
9+ use regex:: Regex ;
10+ use std:: {
11+ iter:: zip,
12+ slice:: { from_raw_parts, from_raw_parts_mut} ,
13+ thread,
914} ;
10- use std:: { slice:: from_raw_parts_mut, thread, usize} ;
11- use test_utils:: { Inference , TokenizerAndPrompt } ;
15+ use test_utils:: { test_infer_paralle, Inference , Task , TokenizerAndPrompt , WorkerSeed } ;
1216
1317type Worker < ' w > = LlamaWorker < Operators , Weights > ;
1418
1519#[ test]
1620fn test_infer ( ) {
1721 let Some ( Inference {
1822 model,
23+ devices,
1924 prompt,
2025 as_user,
2126 temperature,
2227 top_p,
2328 top_k,
2429 max_steps,
25- ..
2630 } ) = Inference :: load ( )
2731 else {
2832 return ;
@@ -41,83 +45,122 @@ fn test_infer() {
4145 let sample_args = SampleArgs :: new ( temperature, top_p, top_k) . expect ( "invalid sample args" ) ;
4246 println ! ( "{sample_args:?}" ) ;
4347
44- infini_rt:: init ( DEVICE_CPU ) ;
45- let device = Device {
46- ty : DEVICE_CPU ,
47- id : 0 ,
48- } ;
49-
50- let meta = & model. meta ;
51- let & LlamaMeta {
52- dt_embd,
53- nctx,
54- nvoc,
55- dh,
56- ..
57- } = meta;
48+ let devices = devices
49+ . map ( |devices| {
50+ Regex :: new ( r"\d+" )
51+ . unwrap ( )
52+ . find_iter ( & devices)
53+ . map ( |c| c. as_str ( ) . parse ( ) . unwrap ( ) )
54+ . collect ( )
55+ } )
56+ . unwrap_or_else ( || vec ! [ 0 ] ) ;
57+ let lens = vec ! [ 1 ; devices. len( ) ] ;
58+ let count = devices. len ( ) ;
59+ println ! ( "distribution: {devices:?}" ) ;
5860
61+ infini_rt:: init ( DEVICE_CPU ) ;
62+ let ( seeds, senders) = WorkerSeed :: new (
63+ devices
64+ . into_iter ( )
65+ . map ( |id| Device { ty : DEVICE_CPU , id } )
66+ . collect ( ) ,
67+ ) ;
5968 thread:: scope ( |s| {
60- let sample = s. spawn ( move || {
61- let mut sample = RandomSample :: new ( & device) ;
62- sample. scheme ( dt_embd, nvoc) . unwrap ( ) ;
63- sample
64- } ) ;
65- let stream = device. stream ( ) ;
66-
67- let token_embd = device. from_host ( model. token_embd ) ;
68- let weights = Weights :: new ( & model, .., 1 , & stream) ;
69- let mut worker = Worker :: new ( & device, meta. clone ( ) , weights, true ) ;
70- let mut cache = meta. kv_cache ( nctx) . map ( |size| stream. malloc :: < u8 > ( size) ) ;
71- let sin_cos = <Operators as llama:: Operators >:: build_sin_cos ( dt_embd, nctx, dh, & stream) ;
72- let indices = RandomSample :: build_indices ( nvoc, & stream) ;
73-
74- let sample = sample. join ( ) . unwrap ( ) ;
75- test_utils:: test_infer ( eos, tokenizer, & prompt, max_steps, |input, pos| {
76- let mut embd = meta. embd ( input. len ( ) ) . map ( |len| stream. malloc :: < u8 > ( len) ) ;
77- let mut logits = meta. logits ( 1 ) . map ( |len| stream. malloc :: < u8 > ( len) ) ;
78-
79- let d = embd. get ( ) . len ( ) / input. len ( ) ;
80- for ( i, & tok) in input. iter ( ) . enumerate ( ) {
81- stream. memcpy_d2d (
82- & mut embd. get_mut ( ) [ i * d..] [ ..d] ,
83- & token_embd[ tok as usize * d..] [ ..d] ,
84- )
85- }
86-
87- worker
88- . launch (
89- LlamaArgs {
90- embd : embd. map_slice_mut ( ) ,
91- logits : logits. map_slice_mut ( ) ,
92- sin_cos : sin_cos. map_slice ( ) ,
93- requests : vec ! [ LlamaRequest {
94- cache: cache. map_slice_mut( ) ,
95- seq_len: input. len( ) ,
96- out_len: 1 ,
69+ let _workers = zip ( lens, seeds)
70+ . enumerate ( )
71+ . scan ( 0 , |start, ( i, ( len, seed) ) | {
72+ let range = * start..* start + len;
73+ * start = range. end ;
74+
75+ let mut meta = model. meta . clone ( ) ;
76+ meta. distribute ( range. clone ( ) , count) ;
77+
78+ let model = & model;
79+ Some ( s. spawn ( move || {
80+ let WorkerSeed { node, tasks } = seed;
81+ let device = node. processor ( ) ;
82+ let stream = device. stream ( ) ;
83+ let weights = Weights :: new ( model, range, count, & stream) ;
84+ let mut worker = Worker :: new ( & node, meta. clone ( ) , weights, i == 0 ) ;
85+ let mut cache = meta
86+ . kv_cache ( meta. nctx )
87+ . map ( |size| stream. malloc :: < u8 > ( size) ) ;
88+ let sin_cos = <Operators as llama:: Operators >:: build_sin_cos (
89+ meta. dt_embd ,
90+ meta. nctx ,
91+ meta. dh ,
92+ & stream,
93+ ) ;
94+
95+ let sample = RandomSample :: new ( & node) ;
96+ let indices = RandomSample :: build_indices ( model. meta . nvoc , & stream) ;
97+ let mut pair = KVPair :: new ( 0 , f16:: ZERO ) ;
98+ let mut pairs = Tensor :: kv_pair_vec ( 1 , |size| stream. malloc :: < u8 > ( size) ) ;
99+
100+ for task in tasks {
101+ let Task {
102+ nt,
97103 pos,
98- } ] ,
99- num_tokens : input. len ( ) ,
100- max_seq_len : input. len ( ) ,
101- max_att_len : pos + input. len ( ) ,
102- } ,
103- & mut [ ] ,
104- & stream,
105- )
106- . unwrap ( ) ;
107-
108- let mut pairs = Tensor :: kv_pair_vec ( 1 , |size| stream. malloc :: < u8 > ( size) ) ;
109-
110- sample
111- . launch ( & mut pairs, & logits, & indices, sample_args, & mut [ ] , & stream)
112- . unwrap ( ) ;
113-
114- let mut pair = KVPair :: new ( 0 , f16:: ZERO ) ;
115- device. memcpy_d2h (
116- unsafe { from_raw_parts_mut ( & mut pair as * mut _ as * mut u8 , size_of_val ( & pair) ) } ,
117- pairs. get ( ) ,
118- ) ;
119-
120- pair. idx ( ) as _
121- } ) ;
122- } ) ;
104+ embd,
105+ next,
106+ } = task;
107+ let mut embd = meta
108+ . embd ( nt)
109+ . map ( |size| stream. from_host ( unsafe { from_raw_parts ( embd, size) } ) ) ;
110+ let mut logits = meta
111+ . logits ( if i == 0 { 1 } else { 0 } )
112+ . map ( |size| stream. malloc :: < u8 > ( size) ) ;
113+ worker
114+ . launch (
115+ llama:: LlamaArgs {
116+ embd : embd. map_slice_mut ( ) ,
117+ logits : logits. map_slice_mut ( ) ,
118+ sin_cos : sin_cos. map_slice ( ) ,
119+ requests : vec ! [ LlamaRequest {
120+ cache: cache. map_slice_mut( ) ,
121+ seq_len: nt,
122+ out_len: if i == 0 { 1 } else { 0 } ,
123+ pos,
124+ } ] ,
125+ num_tokens : nt,
126+ max_seq_len : nt,
127+ max_att_len : nt + pos,
128+ } ,
129+ & mut [ ] ,
130+ & stream,
131+ )
132+ . unwrap ( ) ;
133+ if i == 0 {
134+ sample
135+ . launch (
136+ & mut pairs,
137+ & logits,
138+ & indices,
139+ sample_args,
140+ & mut [ ] ,
141+ & stream,
142+ )
143+ . unwrap ( ) ;
144+
145+ stream. synchronize ( ) ;
146+ device. memcpy_d2h (
147+ unsafe {
148+ from_raw_parts_mut (
149+ & mut pair as * mut _ as * mut u8 ,
150+ pairs. get ( ) . len ( ) ,
151+ )
152+ } ,
153+ pairs. get ( ) ,
154+ ) ;
155+
156+ next. send ( pair. idx ( ) as _ ) . unwrap ( )
157+ }
158+ }
159+ } ) )
160+ } )
161+ . collect :: < Vec < _ > > ( ) ;
162+
163+ let senders = senders. into_boxed_slice ( ) ;
164+ test_infer_paralle ( & model, senders, eos, tokenizer, & prompt, max_steps)
165+ } )
123166}
0 commit comments