11use crate :: { Operators , RandomSample , Weights } ;
2- use gguf:: GGufModel ;
3- use gpt2:: { ext:: ggml_quants:: f16, GPT2Storage , Gpt2Meta , Gpt2Worker , Tensor } ;
2+ use common:: Distribution ;
3+ use gguf:: ext:: utok;
4+ use gguf:: { GGufModel , Tokenizer } ;
5+ use gpt2:: { ext:: ggml_quants:: f16, GPT2Storage , Gpt2Worker , Tensor } ;
6+ use operators:: common_cpu:: InprocNode ;
47use operators:: {
5- common_cpu:: { Cpu , ThisThread } ,
8+ all_reduce:: common_cpu:: Operator as AllReduce ,
9+ common_cpu:: ThisThread ,
610 random_sample:: { KVPair , SampleArgs } ,
711 Blob ,
812} ;
13+ use regex:: Regex ;
14+ use std:: iter:: zip;
15+ use std:: ptr:: copy_nonoverlapping;
916use std:: slice:: from_raw_parts_mut;
10- use test_utils :: { Inference , TokenizerAndPrompt } ;
11-
12- type Worker < ' w > = Gpt2Worker < Operators , Weights < ' w > > ;
17+ use std :: sync :: { mpsc , Arc , Barrier } ;
18+ use std :: thread ;
19+ use test_utils :: { Inference , Task , TokenizerAndPrompt , WorkerSeed } ;
1320
21+ type Worker < ' w > = Gpt2Worker < Operators < InprocNode < usize > , AllReduce > , Weights < ' w > > ;
1422#[ test]
1523fn test_infer ( ) {
1624 let Some ( Inference {
1725 model,
26+ devices,
1827 prompt,
1928 as_user,
2029 temperature,
2130 top_p,
2231 top_k,
2332 max_steps,
24- ..
2533 } ) = Inference :: load ( )
2634 else {
2735 return ;
@@ -40,77 +48,138 @@ fn test_infer() {
4048 let sample_args = SampleArgs :: new ( temperature, top_p, top_k) . expect ( "invalid sample args" ) ;
4149 println ! ( "{sample_args:?}" ) ;
4250
43- let & Gpt2Meta {
44- dt_embd,
45- nctx,
46- nvoc,
47- d,
48- ..
49- } = & model. meta ;
50- let weights = Weights :: new ( & model) ;
51- let mut worker = Worker :: new ( 0 , & Cpu , model. meta . clone ( ) , weights) ;
52- let mut cache = model. meta . kv_cache ( nctx) . map ( Blob :: new) ;
53- let indices = RandomSample :: build_indices ( nvoc, & ThisThread ) ;
54- let sample = RandomSample :: new ( & Cpu ) ;
51+ let lens = devices
52+ . map ( |devices| {
53+ Regex :: new ( r"\d+" )
54+ . unwrap ( )
55+ . find_iter ( & devices)
56+ . map ( |c| c. as_str ( ) . parse ( ) . unwrap ( ) )
57+ . collect ( )
58+ } )
59+ . unwrap_or_else ( || vec ! [ 1 ] ) ;
60+ let dist = lens. iter ( ) . sum ( ) ;
61+ println ! ( "distribution: {lens:?}" ) ;
5562
56- test_utils:: test_infer ( eos, tokenizer, & prompt, max_steps, |input, pos| {
57- // 词汇编码缓存
58- let mut embd = Tensor :: new ( dt_embd, & [ input. len ( ) , d] ) . map ( Blob :: new) ;
59- // 词汇位置缓存
60- let mut logits = model. meta . logits ( 1 ) . map ( Blob :: new) ;
61- let l = embd. get ( ) . len ( ) / input. len ( ) ;
62- for ( i, & tok) in input. iter ( ) . enumerate ( ) {
63- embd. get_mut ( ) [ i * l..] [ ..l]
64- . copy_from_slice ( & model. token_embd [ tok as usize * l..] [ ..l] ) ;
65- }
66- worker
67- . launch (
68- gpt2:: args:: Args {
69- embd : embd. map_slice_mut ( ) ,
70- logits : logits. map_slice_mut ( ) ,
71- idx : postion ( input. len ( ) , pos) . map_slice ( ) ,
72- requests : vec ! [ gpt2:: args:: Request {
73- cache: cache. map_slice_mut( ) ,
74- seq_len: input. len( ) ,
75- out_len: 1 ,
76- pos,
77- } ] ,
78- max_seq_len : input. len ( ) ,
79- max_att_len : pos + input. len ( ) ,
80- } ,
81- & mut [ ] ,
82- & ThisThread ,
83- )
84- . unwrap ( ) ;
63+ let ( seeds, senders) = WorkerSeed :: new ( InprocNode :: new ( lens. len ( ) ) ) ;
64+ let barrier = Arc :: new ( Barrier :: new ( dist + 1 ) ) ;
65+ thread:: scope ( |s| {
66+ let _workers = zip ( lens, seeds)
67+ . enumerate ( )
68+ . scan ( 0 , |start, ( id, ( len, seed) ) | {
69+ let dist = Distribution :: new ( * start, len, dist) ;
70+ * start += len;
8571
86- let mut pair = KVPair :: new ( 0 , f16:: ZERO ) ;
87- let mut pairs = Tensor :: kv_pair_vec ( 1 , |_| unsafe {
88- from_raw_parts_mut ( & mut pair as * mut _ as _ , size_of_val ( & pair) )
89- } ) ;
72+ let meta = model. meta . distribute ( dist) ;
73+ let model = & model;
74+ let barrier = barrier. clone ( ) ;
75+ Some ( s. spawn ( move || {
76+ let WorkerSeed { node, tasks } = seed;
77+ let weights = Weights :: new ( model, dist) ;
78+ let mut worker = Worker :: new ( id, & node, meta. clone ( ) , weights) ;
79+ let mut cache = meta. kv_cache ( meta. nctx ) . map ( Blob :: new) ;
9080
91- sample
92- . launch (
93- & mut pairs,
94- & logits,
95- & indices,
96- sample_args,
97- & mut [ ] ,
98- & ThisThread ,
99- )
100- . unwrap ( ) ;
81+ let sample = RandomSample :: new ( & node) ;
82+ let indices = RandomSample :: build_indices ( model. meta . nvoc , & ThisThread ) ;
83+ let mut pair = KVPair :: new ( 0 , f16:: ZERO ) ;
84+ let mut pairs = Tensor :: kv_pair_vec ( 1 , |_| unsafe {
85+ from_raw_parts_mut ( & mut pair as * mut _ as * mut u8 , size_of_val ( & pair) )
86+ } ) ;
10187
102- pair. idx ( ) as _
103- } ) ;
88+ barrier. wait ( ) ;
89+ for task in tasks {
90+ let Task {
91+ nt,
92+ pos,
93+ embd,
94+ next,
95+ } = task;
96+ let mut embd = meta. embd ( nt) . map ( |size| {
97+ let mut blob = Blob :: new ( size) ;
98+ unsafe { copy_nonoverlapping ( embd, blob. as_mut_ptr ( ) , size) } ;
99+ blob
100+ } ) ;
101+ let mut logits = meta. logits ( if id == 0 { 1 } else { 0 } ) . map ( Blob :: new) ;
102+ worker
103+ . launch (
104+ gpt2:: args:: Args {
105+ embd : embd. map_slice_mut ( ) ,
106+ logits : logits. map_slice_mut ( ) ,
107+ idx : postion ( nt, pos) . map_slice ( ) ,
108+ requests : vec ! [ gpt2:: args:: Request {
109+ cache: cache. map_slice_mut( ) ,
110+ seq_len: nt,
111+ out_len: 1 ,
112+ pos,
113+ } ] ,
114+ max_seq_len : nt,
115+ max_att_len : pos + nt,
116+ } ,
117+ & mut [ ] ,
118+ & ThisThread ,
119+ )
120+ . unwrap ( ) ;
121+ if id == 0 {
122+ sample
123+ . launch (
124+ & mut pairs,
125+ & logits,
126+ & indices,
127+ sample_args,
128+ & mut [ ] ,
129+ & ThisThread ,
130+ )
131+ . unwrap ( ) ;
132+ next. send ( pair. idx ( ) as _ ) . unwrap ( )
133+ }
134+ }
135+ } ) )
136+ } )
137+ . collect :: < Vec < _ > > ( ) ;
138+
139+ let senders = senders. into_boxed_slice ( ) ;
140+ barrier. wait ( ) ;
141+ test_infer_par ( & model, senders, eos, tokenizer, & prompt, max_steps)
142+ } )
104143}
105144
145+ pub fn test_infer_par (
146+ model : & GPT2Storage < & [ u8 ] > ,
147+ senders : Box < [ mpsc:: Sender < Task > ] > ,
148+ eos : utok ,
149+ tokenizer : Tokenizer ,
150+ prompt : & str ,
151+ max_steps : usize ,
152+ ) {
153+ let ( next, next_recv) = mpsc:: channel ( ) ;
154+ test_utils:: test_infer ( eos, tokenizer, prompt, max_steps, |input, pos| {
155+ let mut embd = model. meta . embd ( input. len ( ) ) . map ( Blob :: new) . take ( ) ;
156+
157+ let d = embd. len ( ) / input. len ( ) ;
158+ for ( i, & tok) in input. iter ( ) . enumerate ( ) {
159+ embd[ i * d..] [ ..d] . copy_from_slice ( & model. token_embd [ tok as usize * d..] [ ..d] ) ;
160+ }
161+
162+ for sender in & senders {
163+ sender
164+ . send ( Task {
165+ nt : input. len ( ) ,
166+ pos,
167+ embd : embd. as_ptr ( ) ,
168+ next : next. clone ( ) ,
169+ } )
170+ . unwrap ( )
171+ }
172+ next_recv. recv ( ) . unwrap ( )
173+ } ) ;
174+ }
106175fn postion ( l : usize , pos : usize ) -> Tensor < Blob > {
107176 use gguf:: ggml_quants:: digit_layout:: types as ty;
108177 let mut ans = Tensor :: new ( ty:: U32 , & [ 1 , l] ) . map ( Blob :: new) ;
109178 let ( & mut [ ] , data, & mut [ ] ) = ( unsafe { ans. get_mut ( ) . align_to_mut :: < u32 > ( ) } ) else {
110179 panic ! ( )
111180 } ;
112- for i in 0 ..l {
113- data [ i ] = ( pos + i ) as u32 ;
114- }
181+ data . iter_mut ( )
182+ . enumerate ( )
183+ . for_each ( | ( i , item ) | * item = ( pos + i ) as u32 ) ;
115184 ans
116185}
0 commit comments