@@ -3,6 +3,8 @@ use gguf::ggml_quants::digit_layout::types as ty;
33use gguf:: ggml_quants:: digit_layout:: DigitLayout ;
44use half:: f16;
55use itertools:: { izip, Itertools } ;
6+ use operators:: scale;
7+ use operators:: scale:: Scale ;
68use operators:: {
79 add:: { self , Add } ,
810 all_reduce:: { self , AllReduce , ReduceOp } ,
@@ -30,6 +32,7 @@ pub trait Operators {
3032 type Add : Add < Self :: Hardware > ;
3133 type MatMul : MatMul < Self :: Hardware > ;
3234 type Swiglu : Swiglu < Self :: Hardware > ;
35+ type Scale : Scale < Self :: Hardware > ;
3336 type Rearrange : Rearrange < Self :: Hardware > ;
3437 type AllReduce : AllReduce < Self :: Hardware , Self :: TopoNode > ;
3538
@@ -82,6 +85,7 @@ pub struct Minicpm3Worker<Ops: Operators, W> {
8285 rope : Ops :: Rope ,
8386 rms_norm : Ops :: RmsNorm ,
8487 mat_mul : Ops :: MatMul ,
88+ scale : Ops :: Scale ,
8589 swiglu : Ops :: Swiglu ,
8690 rearrange : Ops :: Rearrange ,
8791 all_reduce : Ops :: AllReduce ,
@@ -98,6 +102,7 @@ impl<Ops: Operators, W> Minicpm3Worker<Ops, W> {
98102 rope : Ops :: Rope :: new ( processor) ,
99103 rms_norm : Ops :: RmsNorm :: new ( processor) ,
100104 mat_mul : Ops :: MatMul :: new ( processor) ,
105+ scale : Ops :: Scale :: new ( processor) ,
101106 swiglu : Ops :: Swiglu :: new ( processor) ,
102107 rearrange : Ops :: Rearrange :: new ( processor) ,
103108 add : Ops :: Add :: new ( processor) ,
@@ -154,18 +159,7 @@ where
154159 let scale_depth = 1.4f32 ;
155160 // 残差连接时权重缩放
156161 let s = scale_depth / ( nblk as f32 ) . sqrt ( ) ;
157- fn ggml_scale ( embd : * mut f16 , s : f16 , l : usize ) {
158- if l == 0 {
159- return ;
160- } // 如果长度为 0,则无需进行任何操作
161-
162- unsafe {
163- let slice = std:: slice:: from_raw_parts_mut ( embd, l) ;
164- slice. iter_mut ( ) . for_each ( |x| * x *= s) ;
165- }
166- }
167-
168- ggml_scale ( x. base_mut ( ) . cast :: < f16 > ( ) , f16:: from_f32 ( scale_emb) , d) ;
162+
169163
170164 let dnope = dk - dh;
171165 let tensor = |shape : & [ usize ] | Tensor :: new ( dt_embd, shape) ;
@@ -184,6 +178,10 @@ where
184178 let mut attn = attn. map ( |_| buf) ;
185179
186180 let queue = queue_alloc. queue ( ) ;
181+ // 缩放
182+ let inplace=unsafe {
183+ x. map_slice_static ( ) } ;
184+ self . scale ( & mut x, & inplace, scale_emb, workspace, queue_alloc) ?;
187185 for iblk in 0 ..nblk {
188186 // norm
189187 let w = self . weights . attn_norm ( iblk, queue) ;
@@ -619,6 +617,31 @@ where
619617 queue_alloc,
620618 )
621619 }
620+ fn scale < C , A , QA > (
621+ & self ,
622+ c : & mut Tensor < C > ,
623+ a : & Tensor < A > ,
624+ s : f32 ,
625+ workspace : & mut [ ByteOf < Ops :: Hardware > ] ,
626+ queue_alloc : & QA ,
627+ ) -> Result < ( ) , LaunchError >
628+ where
629+ C : DerefMut < Target = [ ByteOf < Ops :: Hardware > ] > ,
630+ A : Deref < Target = [ ByteOf < Ops :: Hardware > ] > ,
631+ QA : QueueAlloc < Hardware = Ops :: Hardware > ,
632+ {
633+ self . scale . launch (
634+ & scale:: Args {
635+ c_layout : c. layout ( ) ,
636+ c_base : c. base_mut ( ) ,
637+ a_layout : a. layout ( ) ,
638+ a_base : a. base ( ) ,
639+ s,
640+ } ,
641+ workspace,
642+ queue_alloc,
643+ )
644+ }
622645 fn all_reduce < X , QA > (
623646 & self ,
624647 x : & mut Tensor < X > ,
0 commit comments