1- use clip:: { ClipStorage , WeightLoader } ;
2- use operators:: { common_cpu:: Cpu , conv, QueueOf , TopoNode } ;
3- use std:: marker:: PhantomData ;
1+ use clip:: { BlkWeight , ClipBlkStorage , ClipStorage , Tensor , WeightLoader } ;
2+ use operators:: { common_cpu:: Cpu , conv, ByteOf , QueueOf , TopoNode } ;
3+ use std:: { marker:: PhantomData , ops :: Deref } ;
44
55pub struct Operators < N = Cpu > ( PhantomData < N > ) ;
66
2121 type TopoNode = Cpu ;
2222 type Conv = conv:: common_cpu:: ConvIm2Col ;
2323 type AddRows = op ! ( add_rows) ;
24+ type Rearrange = op ! ( rearrange) ;
2425 type LayerNorm = op ! ( layer_norm) ;
26+ type MatMul = op ! ( mat_mul) ;
27+
28+ fn debug < T > ( tensor : & Tensor < T > )
29+ where
30+ T : Deref < Target = [ ByteOf < Self :: Hardware > ] > ,
31+ {
32+ println ! ( "{tensor}" )
33+ }
2534}
2635
2736impl < ' w > Weights < ' w > {
@@ -32,37 +41,67 @@ impl<'w> Weights<'w> {
3241
3342impl WeightLoader for Weights < ' _ > {
3443 type Hardware = Cpu ;
35- type Weight < ' s >
44+ type Memory < ' s >
3645 = & ' s [ u8 ]
3746 where
3847 Self : ' s ;
3948
49+ fn load_blk (
50+ & self ,
51+ which : BlkWeight ,
52+ iblk : usize ,
53+ _queue : & QueueOf < Self :: Hardware > ,
54+ ) -> [ Self :: Memory < ' _ > ; 2 ] {
55+ let ClipBlkStorage {
56+ attn_norm_w,
57+ attn_norm_b,
58+ attn_qkv_w,
59+ attn_qkv_b,
60+ attn_o_w,
61+ attn_o_b,
62+ ffn_norm_w,
63+ ffn_norm_b,
64+ ffn_up_w,
65+ ffn_up_b,
66+ ffn_down_w,
67+ ffn_down_b,
68+ } = & self . 0 . blocks [ iblk] ;
69+ match which {
70+ BlkWeight :: AttnNorm => [ attn_norm_w, attn_norm_b] ,
71+ BlkWeight :: AttnQKV => [ attn_qkv_w, attn_qkv_b] ,
72+ BlkWeight :: AttnO => [ attn_o_w, attn_o_b] ,
73+ BlkWeight :: FfnNorm => [ ffn_norm_w, ffn_norm_b] ,
74+ BlkWeight :: FfnUp => [ ffn_up_w, ffn_up_b] ,
75+ BlkWeight :: FfnDown => [ ffn_down_w, ffn_down_b] ,
76+ }
77+ }
78+
4079 #[ inline]
41- fn patch_embd < ' a > ( & ' a self , _queue : & ' a QueueOf < Self :: Hardware > ) -> [ Self :: Weight < ' a > ; 2 ] {
80+ fn patch_embd < ' a > ( & ' a self , _queue : & ' a QueueOf < Self :: Hardware > ) -> [ Self :: Memory < ' a > ; 2 ] {
4281 [ self . 0 . patch_embd_w , self . 0 . patch_embd_b ]
4382 }
4483
4584 #[ inline]
46- fn pos_embd < ' a > ( & ' a self , _queue : & ' a QueueOf < Self :: Hardware > ) -> Self :: Weight < ' a > {
85+ fn pos_embd < ' a > ( & ' a self , _queue : & ' a QueueOf < Self :: Hardware > ) -> Self :: Memory < ' a > {
4786 self . 0 . pos_embd
4887 }
4988
5089 #[ inline]
5190 fn pre_norm < ' a > (
5291 & ' a self ,
5392 _queue : & ' a QueueOf < Self :: Hardware > ,
54- ) -> Option < [ Self :: Weight < ' a > ; 2 ] > {
93+ ) -> Option < [ Self :: Memory < ' a > ; 2 ] > {
5594 self . 0 . pre_norm
5695 }
5796
5897 #[ inline]
5998 fn post_norm < ' a > (
6099 & ' a self ,
61100 _queue : & ' a QueueOf < Self :: Hardware > ,
62- ) -> Option < [ Self :: Weight < ' a > ; 2 ] > {
101+ ) -> Option < [ Self :: Memory < ' a > ; 2 ] > {
63102 self . 0 . post_norm
64103 }
65104}
66105
67106#[ cfg( test) ]
68- mod test_infer ;
107+ mod infer ;
0 commit comments