11use super :: { args:: Args , ClipMeta } ;
22use operators:: {
3+ add_rows:: { self , AddRows } ,
34 conv:: { self , Conv } ,
45 ByteOf , Hardware , LaunchError , Operator , QueueAlloc , QueueOf , TopoNode ,
56} ;
@@ -13,6 +14,7 @@ pub trait Operators {
1314 type Hardware : Hardware ;
1415 type TopoNode : TopoNode < Self :: Hardware > ;
1516 type Conv : Conv < Self :: Hardware > ;
17+ type AddRows : AddRows < Self :: Hardware > ;
1618}
1719
1820pub trait WeightLoader {
@@ -22,12 +24,14 @@ pub trait WeightLoader {
2224 Self : ' s ;
2325
2426 fn patch_embd < ' a > ( & ' a self , queue : & ' a QueueOf < Self :: Hardware > ) -> [ Self :: Weight < ' a > ; 2 ] ;
27+ fn pos_embd < ' a > ( & ' a self , queue : & ' a QueueOf < Self :: Hardware > ) -> Self :: Weight < ' a > ;
2528}
2629
2730pub struct ClipWorker < Ops : Operators , W > {
2831 meta : ClipMeta ,
2932 weights : WeightDecorator < W > ,
3033 conv : Ops :: Conv ,
34+ add_rows : Ops :: AddRows ,
3135 pub debug : bool ,
3236}
3337
@@ -38,6 +42,7 @@ impl<Ops: Operators, W> ClipWorker<Ops, W> {
3842 weights : meta. decorator ( weights) ,
3943 meta,
4044 conv : Ops :: Conv :: new ( processor) ,
45+ add_rows : Ops :: AddRows :: new ( processor) ,
4146 debug : true ,
4247 }
4348 }
6469 QA : QueueAlloc < Hardware = Ops :: Hardware > ,
6570 {
6671 let time = Instant :: now ( ) ;
67- let Args { raw, .. } = args;
72+ let Args { raw, pos } = args;
6873 let queue = queue_alloc. queue ( ) ;
6974
7075 let ClipMeta { dt_embd, .. } = self . meta ;
8085 let mut embd = Tensor :: new ( dt_embd, & [ n, m, h / hk, w / wk] ) . map ( |s| queue_alloc. alloc ( s) ) ;
8186 self . conv ( & mut embd, & raw , & k, & b, workspace, queue_alloc) ?;
8287
83- let _embd = embd. merge ( 2 ..4 ) . unwrap ( ) . transpose ( & [ 2 , 1 ] ) ;
88+ let mut embd = embd. merge ( 2 ..4 ) . unwrap ( ) . transpose ( & [ 2 , 1 ] ) ;
89+
90+ let pos_embd = self . weights . pos_embd ( queue) ;
91+ self . add_rows ( & mut embd, & pos_embd, & pos, workspace, queue_alloc) ?;
8492
8593 if self . debug {
8694 println ! ( "encode {n} x {h} x {w} image in {:?}" , time. elapsed( ) ) ;
@@ -130,19 +138,49 @@ where
130138 queue_alloc,
131139 )
132140 }
141+
142+ fn add_rows < Dst , Src , Idx , QA > (
143+ & self ,
144+ dst : & mut Tensor < Dst > ,
145+ src : & Tensor < Src > ,
146+ idx : & Tensor < Idx > ,
147+ workspace : & mut [ ByteOf < Ops :: Hardware > ] ,
148+ queue_alloc : & QA ,
149+ ) -> Result < ( ) , LaunchError >
150+ where
151+ Dst : DerefMut < Target = [ ByteOf < Ops :: Hardware > ] > ,
152+ Src : Deref < Target = [ ByteOf < Ops :: Hardware > ] > ,
153+ Idx : Deref < Target = [ ByteOf < Ops :: Hardware > ] > ,
154+ QA : QueueAlloc < Hardware = Ops :: Hardware > ,
155+ {
156+ self . add_rows . launch (
157+ & add_rows:: Args {
158+ dst_layout : dst. layout ( ) ,
159+ dst_base : dst. base_mut ( ) ,
160+ src_layout : src. layout ( ) ,
161+ src_base : src. base ( ) ,
162+ idx_layout : idx. layout ( ) ,
163+ idx_base : idx. base ( ) ,
164+ } ,
165+ workspace,
166+ queue_alloc,
167+ )
168+ }
133169}
134170
135171struct WeightDecorator < W > {
136172 weights : W ,
137173 patch_embd_w : Tensor < usize > ,
138174 patch_embd_b : Tensor < usize > ,
175+ pos_embd : Tensor < usize > ,
139176}
140177
141178impl ClipMeta {
142179 fn decorator < W > ( & self , weights : W ) -> WeightDecorator < W > {
143180 WeightDecorator {
144181 patch_embd_w : self . patch_embd_w ( ) ,
145182 patch_embd_b : self . patch_embd_b ( ) ,
183+ pos_embd : self . pos_embd ( ) ,
146184 weights,
147185 }
148186 }
@@ -157,4 +195,10 @@ impl<W: WeightLoader> WeightDecorator<W> {
157195 self . patch_embd_b . clone ( ) . map ( |_| b) ,
158196 ]
159197 }
198+
199+ #[ inline]
200+ pub fn pos_embd < ' a > ( & ' a self , queue : & ' a QueueOf < W :: Hardware > ) -> Tensor < W :: Weight < ' a > > {
201+ let pos_embd = self . weights . pos_embd ( queue) ;
202+ self . pos_embd . clone ( ) . map ( |_| pos_embd)
203+ }
160204}
0 commit comments