1- use clip:: { ClipStorage , WeightLoader } ;
2- use operators:: { common_cpu:: Cpu , conv, QueueOf , TopoNode } ;
3- use std:: marker:: PhantomData ;
1+ use clip:: { BlkWeight , ClipBlkStorage , ClipStorage , Tensor , WeightLoader } ;
2+ use operators:: { common_cpu:: Cpu , conv, ByteOf , QueueOf , TopoNode } ;
3+ use std:: { marker:: PhantomData , ops :: Deref } ;
44
55pub struct Operators < N = Cpu > ( PhantomData < N > ) ;
66
2222 type Conv = conv:: common_cpu:: ConvIm2Col ;
2323 type AddRows = op ! ( add_rows) ;
2424 type LayerNorm = op ! ( layer_norm) ;
25+ type MatMul = op ! ( mat_mul) ;
26+ type Attention = op ! ( attention) ;
27+ type Gelu = op ! ( gelu) ;
28+ type Add = op ! ( add) ;
29+ type Rearrange = op ! ( rearrange) ;
30+
31+ fn debug < T > ( tensor : & Tensor < T > )
32+ where
33+ T : Deref < Target = [ ByteOf < Self :: Hardware > ] > ,
34+ {
35+ println ! ( "{tensor}" )
36+ }
2537}
2638
2739impl < ' w > Weights < ' w > {
@@ -32,37 +44,67 @@ impl<'w> Weights<'w> {
3244
3345impl WeightLoader for Weights < ' _ > {
3446 type Hardware = Cpu ;
35- type Weight < ' s >
47+ type Memory < ' s >
3648 = & ' s [ u8 ]
3749 where
3850 Self : ' s ;
3951
52+ fn load_blk (
53+ & self ,
54+ which : BlkWeight ,
55+ iblk : usize ,
56+ _queue : & QueueOf < Self :: Hardware > ,
57+ ) -> [ Self :: Memory < ' _ > ; 2 ] {
58+ let ClipBlkStorage {
59+ attn_norm_w,
60+ attn_norm_b,
61+ attn_qkv_w,
62+ attn_qkv_b,
63+ attn_o_w,
64+ attn_o_b,
65+ ffn_norm_w,
66+ ffn_norm_b,
67+ ffn_up_w,
68+ ffn_up_b,
69+ ffn_down_w,
70+ ffn_down_b,
71+ } = & self . 0 . blocks [ iblk] ;
72+ match which {
73+ BlkWeight :: AttnNorm => [ attn_norm_w, attn_norm_b] ,
74+ BlkWeight :: AttnQKV => [ attn_qkv_w, attn_qkv_b] ,
75+ BlkWeight :: AttnO => [ attn_o_w, attn_o_b] ,
76+ BlkWeight :: FfnNorm => [ ffn_norm_w, ffn_norm_b] ,
77+ BlkWeight :: FfnUp => [ ffn_up_w, ffn_up_b] ,
78+ BlkWeight :: FfnDown => [ ffn_down_w, ffn_down_b] ,
79+ }
80+ }
81+
4082 #[ inline]
41- fn patch_embd < ' a > ( & ' a self , _queue : & ' a QueueOf < Self :: Hardware > ) -> [ Self :: Weight < ' a > ; 2 ] {
83+ fn patch_embd < ' a > ( & ' a self , _queue : & ' a QueueOf < Self :: Hardware > ) -> [ Self :: Memory < ' a > ; 2 ] {
4284 [ self . 0 . patch_embd_w , self . 0 . patch_embd_b ]
4385 }
4486
4587 #[ inline]
46- fn pos_embd < ' a > ( & ' a self , _queue : & ' a QueueOf < Self :: Hardware > ) -> Self :: Weight < ' a > {
88+ fn pos_embd < ' a > ( & ' a self , _queue : & ' a QueueOf < Self :: Hardware > ) -> Self :: Memory < ' a > {
4789 self . 0 . pos_embd
4890 }
4991
5092 #[ inline]
5193 fn pre_norm < ' a > (
5294 & ' a self ,
5395 _queue : & ' a QueueOf < Self :: Hardware > ,
54- ) -> Option < [ Self :: Weight < ' a > ; 2 ] > {
96+ ) -> Option < [ Self :: Memory < ' a > ; 2 ] > {
5597 self . 0 . pre_norm
5698 }
5799
58100 #[ inline]
59101 fn post_norm < ' a > (
60102 & ' a self ,
61103 _queue : & ' a QueueOf < Self :: Hardware > ,
62- ) -> Option < [ Self :: Weight < ' a > ; 2 ] > {
104+ ) -> Option < [ Self :: Memory < ' a > ; 2 ] > {
63105 self . 0 . post_norm
64106 }
65107}
66108
67109#[ cfg( test) ]
68- mod test_infer ;
110+ mod infer ;
0 commit comments