11use crate :: { Operators , RandomSample , Weights } ;
2+ use common:: Distribution ;
23use gguf:: GGufModel ;
3- use gpt2:: { ext:: ggml_quants:: f16, Gpt2Meta , Gpt2Worker , Storage , Tensor } ;
4+ use gpt2:: { ext:: ggml_quants:: f16, GPT2Storage , Gpt2Worker , Tensor } ;
45use operators:: {
5- common_cpu:: { Cpu , ThisThread } ,
6+ all_reduce:: common_cpu:: Operator as AllReduce ,
7+ common_cpu:: { InprocNode , ThisThread } ,
68 random_sample:: { KVPair , SampleArgs } ,
79 Blob ,
810} ;
9- use std:: slice:: from_raw_parts_mut;
10- use test_utils:: { Inference , TokenizerAndPrompt } ;
11-
12- type Worker < ' w > = Gpt2Worker < Operators , Weights < ' w > > ;
11+ use regex:: Regex ;
12+ use std:: {
13+ iter:: zip,
14+ ptr:: copy_nonoverlapping,
15+ slice:: from_raw_parts_mut,
16+ sync:: { Arc , Barrier } ,
17+ thread,
18+ } ;
19+ use test_utils:: { test_infer_paralle, Inference , Task , TokenizerAndPrompt , WorkerSeed } ;
1320
21+ type Worker < ' w > = Gpt2Worker < Operators < InprocNode < usize > , AllReduce > , Weights < ' w > > ;
1422#[ test]
1523fn test_infer ( ) {
1624 let Some ( Inference {
1725 model,
26+ devices,
1827 prompt,
1928 as_user,
2029 temperature,
2130 top_p,
2231 top_k,
2332 max_steps,
24- ..
2533 } ) = Inference :: load ( )
2634 else {
2735 return ;
@@ -34,73 +42,104 @@ fn test_infer() {
3442 prompt,
3543 } = TokenizerAndPrompt :: new ( & gguf, prompt, as_user) ;
3644
37- let model = Storage :: from_gguf ( & gguf) ;
45+ let model = GPT2Storage :: from_gguf ( & gguf) ;
3846 println ! ( "{:?}" , model. meta) ;
3947
4048 let sample_args = SampleArgs :: new ( temperature, top_p, top_k) . expect ( "invalid sample args" ) ;
4149 println ! ( "{sample_args:?}" ) ;
4250
43- let & Gpt2Meta {
44- dt_embd,
45- nctx,
46- nvoc,
47- d,
48- ..
49- } = & model. meta ;
50- let weights = Weights :: new ( & model) ;
51- let mut worker = Worker :: new ( & Cpu , model. meta . clone ( ) , weights) ;
52- let mut cache = model. meta . kv_cache ( nctx) . map ( Blob :: new) ;
53- let indices = RandomSample :: build_indices ( nvoc, & ThisThread ) ;
54- let sample = RandomSample :: new ( & Cpu ) ;
51+ let lens = devices
52+ . map ( |devices| {
53+ Regex :: new ( r"\d+" )
54+ . unwrap ( )
55+ . find_iter ( & devices)
56+ . map ( |c| c. as_str ( ) . parse ( ) . unwrap ( ) )
57+ . collect ( )
58+ } )
59+ . unwrap_or_else ( || vec ! [ 1 ] ) ;
60+ let dist = lens. iter ( ) . sum ( ) ;
61+ println ! ( "distribution: {lens:?}" ) ;
62+
63+ let ( seeds, senders) = WorkerSeed :: new ( InprocNode :: new ( lens. len ( ) ) ) ;
64+ let barrier = Arc :: new ( Barrier :: new ( dist + 1 ) ) ;
65+ thread:: scope ( |s| {
66+ let _workers = zip ( lens, seeds)
67+ . enumerate ( )
68+ . scan ( 0 , |start, ( id, ( len, seed) ) | {
69+ let dist = Distribution :: new ( * start, len, dist) ;
70+ * start += len;
5571
56- test_utils:: test_infer ( eos, tokenizer, & prompt, max_steps, |input, pos| {
57- // 词汇编码缓存
58- let mut embd = Tensor :: new ( dt_embd, & [ input. len ( ) , d] ) . map ( Blob :: new) ;
59- // 词汇位置缓存
60- let mut logits = model. meta . logits ( 1 ) . map ( Blob :: new) ;
61- let l = embd. get ( ) . len ( ) / input. len ( ) ;
62- for ( i, & tok) in input. iter ( ) . enumerate ( ) {
63- embd. get_mut ( ) [ i * l..] [ ..l]
64- . copy_from_slice ( & model. token_embd [ tok as usize * l..] [ ..l] ) ;
65- }
66- worker
67- . launch (
68- gpt2:: args:: Args {
69- embd : embd. map_slice_mut ( ) ,
70- logits : logits. map_slice_mut ( ) ,
71- idx : postion ( input. len ( ) , pos) . map_slice ( ) ,
72- requests : vec ! [ gpt2:: args:: Request {
73- cache: cache. map_slice_mut( ) ,
74- seq_len: input. len( ) ,
75- out_len: 1 ,
76- pos,
77- } ] ,
78- max_seq_len : input. len ( ) ,
79- max_att_len : pos + input. len ( ) ,
80- } ,
81- & mut [ ] ,
82- & ThisThread ,
83- )
84- . unwrap ( ) ;
72+ let meta = model. meta . distribute ( dist) ;
73+ let model = & model;
74+ let barrier = barrier. clone ( ) ;
75+ Some ( s. spawn ( move || {
76+ let WorkerSeed { node, tasks } = seed;
77+ let weights = Weights :: new ( model, dist) ;
78+ let mut worker = Worker :: new ( id, & node, meta. clone ( ) , weights) ;
79+ let mut cache = meta. kv_cache ( meta. nctx ) . map ( Blob :: new) ;
8580
86- let mut pair = KVPair :: new ( 0 , f16:: ZERO ) ;
87- let mut pairs = Tensor :: kv_pair_vec ( 1 , |_| unsafe {
88- from_raw_parts_mut ( & mut pair as * mut _ as _ , size_of_val ( & pair) )
89- } ) ;
81+ let sample = RandomSample :: new ( & node) ;
82+ let indices = RandomSample :: build_indices ( model. meta . nvoc , & ThisThread ) ;
83+ let mut pair = KVPair :: new ( 0 , f16:: ZERO ) ;
84+ let mut pairs = Tensor :: kv_pair_vec ( 1 , |_| unsafe {
85+ from_raw_parts_mut ( & mut pair as * mut _ as * mut u8 , size_of_val ( & pair) )
86+ } ) ;
9087
91- sample
92- . launch (
93- & mut pairs,
94- & logits,
95- & indices,
96- sample_args,
97- & mut [ ] ,
98- & ThisThread ,
99- )
100- . unwrap ( ) ;
88+ barrier. wait ( ) ;
89+ for task in tasks {
90+ let Task {
91+ nt,
92+ pos,
93+ embd,
94+ next,
95+ } = task;
96+ let mut embd = meta. embd ( nt) . map ( |size| {
97+ let mut blob = Blob :: new ( size) ;
98+ unsafe { copy_nonoverlapping ( embd, blob. as_mut_ptr ( ) , size) } ;
99+ blob
100+ } ) ;
101+ let mut logits = meta. logits ( if id == 0 { 1 } else { 0 } ) . map ( Blob :: new) ;
102+ worker
103+ . launch (
104+ gpt2:: args:: Args {
105+ embd : embd. map_slice_mut ( ) ,
106+ logits : logits. map_slice_mut ( ) ,
107+ idx : postion ( nt, pos) . map_slice ( ) ,
108+ requests : vec ! [ gpt2:: args:: Request {
109+ cache: cache. map_slice_mut( ) ,
110+ seq_len: nt,
111+ out_len: 1 ,
112+ pos,
113+ } ] ,
114+ max_seq_len : nt,
115+ max_att_len : pos + nt,
116+ } ,
117+ & mut [ ] ,
118+ & ThisThread ,
119+ )
120+ . unwrap ( ) ;
121+ if id == 0 {
122+ sample
123+ . launch (
124+ & mut pairs,
125+ & logits,
126+ & indices,
127+ sample_args,
128+ & mut [ ] ,
129+ & ThisThread ,
130+ )
131+ . unwrap ( ) ;
132+ next. send ( pair. idx ( ) as _ ) . unwrap ( )
133+ }
134+ }
135+ } ) )
136+ } )
137+ . collect :: < Vec < _ > > ( ) ;
101138
102- pair. idx ( ) as _
103- } ) ;
139+ let senders = senders. into_boxed_slice ( ) ;
140+ barrier. wait ( ) ;
141+ test_infer_paralle ( & model, senders, eos, tokenizer, & prompt, max_steps)
142+ } )
104143}
105144
106145fn postion ( l : usize , pos : usize ) -> Tensor < Blob > {
@@ -109,8 +148,8 @@ fn postion(l: usize, pos: usize) -> Tensor<Blob> {
109148 let ( & mut [ ] , data, & mut [ ] ) = ( unsafe { ans. get_mut ( ) . align_to_mut :: < u32 > ( ) } ) else {
110149 panic ! ( )
111150 } ;
112- for i in 0 ..l {
113- data [ i ] = ( pos + i ) as u32 ;
114- }
151+ data . iter_mut ( )
152+ . enumerate ( )
153+ . for_each ( | ( i , item ) | * item = ( pos + i ) as u32 ) ;
115154 ans
116155}