@@ -152,7 +152,7 @@ where
152152 // llama.cpp 定义死
153153 let scale_emb = 12f32 ;
154154 let scale_depth = 1.4f32 ;
155- // 提前进行缩放
155+ // 残差连接时权重缩放
156156 let s = scale_depth / ( nblk as f32 ) . sqrt ( ) ;
157157 fn ggml_scale ( embd : * mut f16 , s : f16 , l : usize ) {
158158 if l == 0 {
@@ -181,6 +181,11 @@ where
181181 let mut q = q. map ( |_| buf) ;
182182 let ( buf, workspace) = workspace. split_at_mut ( * kv_pe. get ( ) ) ;
183183 let mut kv_pe = kv_pe. map ( |_| buf) ;
184+ // 经行 attention
185+ let attn = tensor ( & [ nt, nh, dv] ) ;
186+ let ( buf, workspace) = workspace. split_at_mut ( * attn. get ( ) ) ;
187+ let mut attn = attn. map ( |_| buf) ;
188+
184189 let queue = queue_alloc. queue ( ) ;
185190 for iblk in 0 ..nblk {
186191 // norm
@@ -323,15 +328,11 @@ where
323328 let k_rope_2 = k_rope_0. tile ( 1 , & [ 1 , dh] ) . broadcast ( 1 , nh) ;
324329 self . rearrange ( & mut k_rope_r, & k_rope_2, workspace, queue_alloc) ?;
325330 self . rearrange ( & mut k_nope_r, & k_nope, workspace, queue_alloc) ?;
326- // 经行 attention
327- let attn = tensor ( & [ nt, nh, dv] ) ;
328- let ( buf, workspace) = workspace. split_at_mut ( * attn. get ( ) ) ;
329- let mut attn = attn. map ( |_| buf) ;
330331
331332 let mut q = q3. transpose ( & [ 1 , 0 ] ) ;
332333 let k = k. map_slice ( ) . transpose ( & [ 1 , 0 ] ) ;
333334 let mut v = v. map_slice_mut ( ) . transpose ( & [ 1 , 0 ] ) ;
334- let mut attn = attn. transpose ( & [ 1 , 0 ] ) ;
335+ let mut attn = unsafe { attn. map_slice_mut ( ) . transpose ( & [ 1 , 0 ] ) } ;
335336 self . attnention (
336337 & mut q,
337338 & k,
@@ -346,12 +347,9 @@ where
346347 let w = self . weights . attn_o ( iblk, queue) ;
347348
348349 self . mat_mul ( & mut x1, 0. , & o, & w, s, workspace, queue_alloc) ?;
350+ let inplace = unsafe { x. map_slice_static ( ) } ;
351+ self . add ( & mut x, & inplace, & x1, workspace, queue_alloc) ?;
349352 }
350- let inplace = unsafe { x. map_slice_static ( ) } ;
351- //是否给 add 加上缩放系数
352-
353- self . add ( & mut x, & inplace, & x1, workspace, queue_alloc) ?;
354-
355353 let w = self . weights . ffn_norm ( iblk, queue) ;
356354 self . rms_norm ( & mut x1, & x, & w, workspace, queue_alloc) ?;
357355 drop ( w) ;
@@ -361,29 +359,32 @@ where
361359 split ! ( gate_up => gate, up; [ di, di] @ 1 ) ;
362360 let mut gate = gate;
363361 let mut up = up;
364- let w = self . weights . ffn_gate ( iblk, queue) . transpose ( & [ 0 , 1 ] ) ;
362+ let w = self . weights . ffn_gate ( iblk, queue) ;
365363 self . mat_mul ( & mut gate, 0. , & x1, & w, 1. , workspace, queue_alloc) ?;
366- // Ops::debug(&w, queue);
364+
365+ let w = self . weights . ffn_up ( iblk, queue) ;
366+ self . mat_mul ( & mut up, 0. , & x1, & w, 1. , workspace, queue_alloc) ?;
367+
368+ self . swiglu ( & mut gate, & up, workspace, queue_alloc) ?;
367369
368370 fn print_first_10_elements ( ptr : * const f16 ) {
369371 assert ! ( !ptr. is_null( ) , "Pointer must not be null" ) ;
370372
371373 unsafe {
372374 for i in 0 ..10 {
373- // 逐个访问并打印前10个元素
375+ // 逐个访问并打印前 10 个元素
374376 let element = ptr. offset ( i as isize ) . read ( ) ;
375377 println ! ( "Element {}: {:?}" , i, element) ;
376378 }
377379 }
378380 }
379- print_first_10_elements ( w. base ( ) . cast :: < f16 > ( ) ) ;
380- todo ! ( ) ;
381-
382- self . swiglu ( & mut gate, & up, workspace, queue_alloc) ?;
383381
384382 let w = self . weights . ffn_down ( iblk, queue) ;
385- let residual = if self . id == 0 { 1. } else { 0. } ;
386- self . mat_mul ( & mut x, residual, & gate, & w, 1. , workspace, queue_alloc) ?;
383+ self . mat_mul ( & mut x1, 0. , & gate, & w, s, workspace, queue_alloc) ?;
384+
385+ let inplace = unsafe { x. map_slice_static ( ) } ;
386+ self . add ( & mut x, & inplace, & x1, workspace, queue_alloc) ?;
387+
387388 self . all_reduce ( & mut x, workspace, queue_alloc) ?
388389 }
389390 if logits. shape ( ) [ 0 ] == 0 {
@@ -808,7 +809,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
808809 iblk : usize ,
809810 queue : & ' a QueueOf < W :: Hardware > ,
810811 ) -> Tensor < W :: Weight < ' a > > {
811- const WHICH : MiniCPM3BlkWeight = MiniCPM3BlkWeight :: FfnGateUp ;
812+ const WHICH : MiniCPM3BlkWeight = MiniCPM3BlkWeight :: FfnGate ;
812813 let w = self . weights . load_blk ( WHICH , iblk, queue) ;
813814 self . ffn_gate . clone ( ) . map ( |_| w)
814815 }
@@ -818,7 +819,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
818819 iblk : usize ,
819820 queue : & ' a QueueOf < W :: Hardware > ,
820821 ) -> Tensor < W :: Weight < ' a > > {
821- const WHICH : MiniCPM3BlkWeight = MiniCPM3BlkWeight :: FfnGateUp ;
822+ const WHICH : MiniCPM3BlkWeight = MiniCPM3BlkWeight :: FfnUp ;
822823 let w = self . weights . load_blk ( WHICH , iblk, queue) ;
823824 self . ffn_up . clone ( ) . map ( |_| w)
824825 }
0 commit comments