 #![cfg(detected)]
 
-use common::{Distribution, WeightMemCalculator};
+use common::Distribution;
 use llama::{LlamaBlkStorage, LlamaBlkWeight, LlamaStorage, Tensor, WeightLoader};
 use operators::{
     all_reduce::{AllReduce, NonAllReduce},
@@ -10,12 +10,7 @@ use operators::{
     rearrange::opencl::Operator as Rearrange,
     Blob, ByteOf, QueueOf, TopoNode,
 };
-use std::{
-    iter::zip,
-    marker::PhantomData,
-    ops::{Deref, Range},
-    ptr::copy_nonoverlapping,
-};
+use std::{marker::PhantomData, ops::Deref, ptr::copy_nonoverlapping};
 
 pub struct Operators<N = ClDevice, R = NonAllReduce<ClDevice, Rearrange>>(PhantomData<(N, R)>);
 
@@ -65,10 +60,9 @@
 
 pub struct Weights {
     nexp: usize,
-    mem: SvmBlob,
-    blks: Box<[LlamaBlkStorage<Range<usize>>]>,
-    output_norm: Range<usize>,
-    output: Range<usize>,
+    blks: Box<[LlamaBlkStorage<SvmBlob>]>,
+    output_norm: SvmBlob,
+    output: SvmBlob,
 }
 
 impl Weights {
@@ -81,52 +75,40 @@ impl Weights {
             ..
         } = model;
 
-        let mut calculator = WeightMemCalculator::new(size_of::<usize>());
-        let meta_dist = meta.distribute(dist);
-        let blk_size = meta_dist.blk();
-        let off_blks = (0..meta_dist.nblk)
-            .map(|_| {
-                blk_size
-                    .clone()
+        let meta = meta.distribute(dist);
+        let queue = ctx.queue();
+        let blks = blocks
+            .iter()
+            .map(|blk| {
+                blk.clone()
                     .into_vec()
                     .into_iter()
-                    .map(|(which, size)| (which, calculator.push(size)))
+                    .map(|(which, data)| {
+                        let blob = meta.distribute_data(which, data, dist, Blob::new);
+                        let mut svm = ctx.malloc::<u8>(blob.len());
+                        let mut map = queue.map_mut(&mut svm, false);
+                        map.copy_from_slice(&blob);
+                        queue.unmap(map);
+                        (which, svm)
+                    })
                     .collect::<LlamaBlkStorage<_>>()
             })
             .collect::<Vec<_>>();
-        let off_output_norm = calculator.push(output_norm.len());
-        let off_output = calculator.push(output.len());
 
-        let mut mem = ctx.malloc::<u8>(calculator.size());
-        let queue = ctx.queue();
-
-        for (blk, off) in zip(blocks, off_blks.clone()) {
-            let blk = blk.clone().into_vec();
-            let off = off.into_vec();
-            for ((which, data), (which_, off)) in zip(blk, off) {
-                assert_eq!(which, which_);
-                if off.is_empty() {
-                    continue;
-                }
-                let data = meta.distribute_data(which, data, dist, Blob::new);
-                let mut map = queue.map_mut(&mut mem[off], false);
-                map.copy_from_slice(&data);
-                queue.unmap(map)
-            }
-        }
-        let mut map = queue.map_mut(&mut mem[off_output_norm.clone()], false);
-        map.copy_from_slice(output_norm);
-        queue.unmap(map);
-        let mut map = queue.map_mut(&mut mem[off_output.clone()], false);
-        map.copy_from_slice(output);
-        queue.unmap(map);
+        let mut output_norm_svm = ctx.malloc::<u8>(output_norm.len());
+        let mut output_svm = ctx.malloc::<u8>(output.len());
+        let mut output_norm_map = queue.map_mut(&mut output_norm_svm, false);
+        let mut output_map = queue.map_mut(&mut output_svm, false);
+        output_norm_map.copy_from_slice(output_norm);
+        output_map.copy_from_slice(output);
+        queue.unmap(output_norm_map);
+        queue.unmap(output_map);
 
         Self {
             nexp: meta.nexp,
-            mem,
-            blks: off_blks.into_boxed_slice(),
-            output_norm: off_output_norm,
-            output: off_output,
+            blks: blks.into_boxed_slice(),
+            output_norm: output_norm_svm,
+            output: output_svm,
         }
     }
 }
@@ -158,7 +140,7 @@ impl WeightLoader for Weights {
 
         use LlamaBlkWeight as W;
         #[rustfmt::skip]
-        let range = match which {
+        let ans = match which {
             W::AttnNorm    => attn_norm    ,
             W::AttnQKV     => attn_qkv     ,
             W::AttnQKVBias => attn_qkv_bias,
@@ -168,7 +150,7 @@ impl WeightLoader for Weights {
             W::FfnGateUp   => ffn_gate_up  ,
             W::FfnDown     => ffn_down     ,
         };
-        &self.mem[range.clone()]
+        ans
     }
 
     fn load_moe<'a>(
@@ -184,26 +166,25 @@ impl WeightLoader for Weights {
             ..
         } = &self.blks[iblk];
 
-        let range = match which {
+        let w = match which {
             LlamaBlkWeight::FfnGateUp => ffn_gate_up,
             LlamaBlkWeight::FfnDown => ffn_down,
             _ => unreachable!(),
         };
-        let w = &self.mem[range.clone()];
         let one = w.len() / self.nexp;
         &w[iexp * one..][..one]
     }
 
     #[inline]
     fn output_norm(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Weight<'_> {
-        &self.mem[self.output_norm.clone()]
+        &self.output_norm
     }
 
     #[inline]
     fn output(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Weight<'_> {
-        &self.mem[self.output.clone()]
+        &self.output
     }
 }
 
 #[cfg(test)]
-mod infer;
+mod infer;