@@ -2,6 +2,7 @@ use super::{args::Args, ClipMeta};
22use operators:: {
33 add_rows:: { self , AddRows } ,
44 conv:: { self , Conv } ,
5+ layer_norm:: { self , LayerNorm } ,
56 ByteOf , Hardware , LaunchError , Operator , QueueAlloc , QueueOf , TopoNode ,
67} ;
78use std:: {
@@ -15,6 +16,7 @@ pub trait Operators {
1516 type TopoNode : TopoNode < Self :: Hardware > ;
1617 type Conv : Conv < Self :: Hardware > ;
1718 type AddRows : AddRows < Self :: Hardware > ;
19+ type LayerNorm : LayerNorm < Self :: Hardware > ;
1820}
1921
2022pub trait WeightLoader {
@@ -25,13 +27,17 @@ pub trait WeightLoader {
2527
2628 fn patch_embd < ' a > ( & ' a self , queue : & ' a QueueOf < Self :: Hardware > ) -> [ Self :: Weight < ' a > ; 2 ] ;
2729 fn pos_embd < ' a > ( & ' a self , queue : & ' a QueueOf < Self :: Hardware > ) -> Self :: Weight < ' a > ;
30+ fn pre_norm < ' a > ( & ' a self , queue : & ' a QueueOf < Self :: Hardware > ) -> Option < [ Self :: Weight < ' a > ; 2 ] > ;
31+ fn post_norm < ' a > ( & ' a self , queue : & ' a QueueOf < Self :: Hardware > )
32+ -> Option < [ Self :: Weight < ' a > ; 2 ] > ;
2833}
2934
3035pub struct ClipWorker < Ops : Operators , W > {
3136 meta : ClipMeta ,
3237 weights : WeightDecorator < W > ,
3338 conv : Ops :: Conv ,
3439 add_rows : Ops :: AddRows ,
40+ layer_norm : Ops :: LayerNorm ,
3541 pub debug : bool ,
3642}
3743
@@ -43,6 +49,7 @@ impl<Ops: Operators, W> ClipWorker<Ops, W> {
4349 meta,
4450 conv : Ops :: Conv :: new ( processor) ,
4551 add_rows : Ops :: AddRows :: new ( processor) ,
52+ layer_norm : Ops :: LayerNorm :: new ( processor) ,
4653 debug : true ,
4754 }
4855 }
9097 let pos_embd = self . weights . pos_embd ( queue) ;
9198 self . add_rows ( & mut embd, & pos_embd, & pos, workspace, queue_alloc) ?;
9299
100+ if let Some ( [ scale, bias] ) = self . weights . pre_norm ( queue) {
101+ let inplace = unsafe { embd. map_slice_static ( ) } ;
102+ self . layer_norm ( & mut embd, & inplace, & scale, & bias, workspace, queue_alloc) ?;
103+ }
104+
105+ for _ in 0 ..self . meta . nblk { }
106+
107+ if let Some ( [ scale, bias] ) = self . weights . post_norm ( queue) {
108+ let inplace = unsafe { embd. map_slice_static ( ) } ;
109+ self . layer_norm ( & mut embd, & inplace, & scale, & bias, workspace, queue_alloc) ?;
110+ }
111+
93112 if self . debug {
94113 println ! ( "encode {n} x {h} x {w} image in {:?}" , time. elapsed( ) ) ;
95114 }
@@ -166,13 +185,47 @@ where
166185 queue_alloc,
167186 )
168187 }
188+
189+ fn layer_norm < Y , X , Scale , Bias , QA > (
190+ & self ,
191+ y : & mut Tensor < Y > ,
192+ x : & Tensor < X > ,
193+ scale : & Tensor < Scale > ,
194+ bias : & Tensor < Bias > ,
195+ workspace : & mut [ ByteOf < Ops :: Hardware > ] ,
196+ queue_alloc : & QA ,
197+ ) -> Result < ( ) , LaunchError >
198+ where
199+ Y : DerefMut < Target = [ ByteOf < Ops :: Hardware > ] > ,
200+ X : Deref < Target = [ ByteOf < Ops :: Hardware > ] > ,
201+ Scale : Deref < Target = [ ByteOf < Ops :: Hardware > ] > ,
202+ Bias : Deref < Target = [ ByteOf < Ops :: Hardware > ] > ,
203+ QA : QueueAlloc < Hardware = Ops :: Hardware > ,
204+ {
205+ self . layer_norm . launch (
206+ & layer_norm:: Args {
207+ y_layout : y. layout ( ) ,
208+ y_base : y. base_mut ( ) ,
209+ x_layout : x. layout ( ) ,
210+ x_base : x. base ( ) ,
211+ scale_layout : scale. layout ( ) ,
212+ scale_base : scale. base ( ) ,
213+ bias_layout : bias. layout ( ) ,
214+ bias_base : bias. base ( ) ,
215+ epsilon : self . meta . epsilon ,
216+ } ,
217+ workspace,
218+ queue_alloc,
219+ )
220+ }
169221}
170222
171223struct WeightDecorator < W > {
172224 weights : W ,
173225 patch_embd_w : Tensor < usize > ,
174226 patch_embd_b : Tensor < usize > ,
175227 pos_embd : Tensor < usize > ,
228+ norm : Tensor < usize > ,
176229}
177230
178231impl ClipMeta {
@@ -181,6 +234,7 @@ impl ClipMeta {
181234 patch_embd_w : self . patch_embd_w ( ) ,
182235 patch_embd_b : self . patch_embd_b ( ) ,
183236 pos_embd : self . pos_embd ( ) ,
237+ norm : self . norm ( ) ,
184238 weights,
185239 }
186240 }
@@ -201,4 +255,24 @@ impl<W: WeightLoader> WeightDecorator<W> {
201255 let pos_embd = self . weights . pos_embd ( queue) ;
202256 self . pos_embd . clone ( ) . map ( |_| pos_embd)
203257 }
258+
259+ #[ inline]
260+ pub fn pre_norm < ' a > (
261+ & ' a self ,
262+ queue : & ' a QueueOf < W :: Hardware > ,
263+ ) -> Option < [ Tensor < W :: Weight < ' a > > ; 2 ] > {
264+ self . weights
265+ . pre_norm ( queue)
266+ . map ( |pair| pair. map ( |w| self . norm . clone ( ) . map ( |_| w) ) )
267+ }
268+
269+ #[ inline]
270+ pub fn post_norm < ' a > (
271+ & ' a self ,
272+ queue : & ' a QueueOf < W :: Hardware > ,
273+ ) -> Option < [ Tensor < W :: Weight < ' a > > ; 2 ] > {
274+ self . weights
275+ . post_norm ( queue)
276+ . map ( |pair| pair. map ( |w| self . norm . clone ( ) . map ( |_| w) ) )
277+ }
204278}
0 commit comments