@@ -3,6 +3,8 @@ use gguf::ggml_quants::digit_layout::types as ty;
 use gguf::ggml_quants::digit_layout::DigitLayout;
 use half::f16;
 use itertools::Itertools;
+use operators::fuesd_softmax;
+use operators::fuesd_softmax::FusedSoftmax;
 use operators::scale;
 use operators::scale::Scale;
 use operators::{
@@ -19,6 +21,7 @@ use operators::{
     ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode, Workspace,
 };
 use std::ops::{Deref, DerefMut};
+use std::process::Output;
 use tensor::split_mut;
 use tensor::{split, Tensor};
 
@@ -33,6 +36,7 @@ pub trait Operators {
     type MatMul: MatMul<Self::Hardware>;
     type Swiglu: Swiglu<Self::Hardware>;
     type Scale: Scale<Self::Hardware>;
+    type FuesdSoftmax: FusedSoftmax<Self::Hardware>;
     type Rearrange: Rearrange<Self::Hardware>;
     type AllReduce: AllReduce<Self::Hardware, Self::TopoNode>;
 
@@ -87,6 +91,7 @@ pub struct Minicpm3Worker<Ops: Operators, W> {
     mat_mul: Ops::MatMul,
     scale: Ops::Scale,
     swiglu: Ops::Swiglu,
+    fuesd_softmax: Ops::FuesdSoftmax,
     rearrange: Ops::Rearrange,
     all_reduce: Ops::AllReduce,
 }
@@ -109,6 +114,7 @@ impl<Ops: Operators, W> Minicpm3Worker<Ops, W> {
             all_reduce: Ops::AllReduce::new(node),
             dt_pos: ty::U64,
             attention: Ops::Attention::new(processor),
+            fuesd_softmax: Ops::FuesdSoftmax::new(processor),
         }
     }
 
@@ -165,12 +171,11 @@ where
 
         let gate_up = tensor(&[nt, di * 2]);
         // workspace layout: x + x1 + q (can probably be removed) + q3 + kv_pe + attn
-        let workspace_size = *x1.get() * 3 + *gate_up.get();
+        let workspace_size = *x1.get() * 20 + *gate_up.get();
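+        // (assumption) the `* 20` factor looks like a deliberately generous upper bound so that this
+        // single arena can also hold the new MLA intermediates (q1, kv_pe, kv, attn_weights,
+        // attn_output, ...) rather than an exact sum of their sizes.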
         let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);
         let (buf, workspace) = workspace.split_at_mut(*x1.get());
         let mut x1 = x1.map(|_| buf);
 
-
         let queue = queue_alloc.queue();
 
         let sin = sin_cos.clone().index(0, 0);
@@ -205,17 +210,15 @@ where
             let w = self.weights.attn_qa_norm(iblk, queue);
             self.rms_norm(&mut q, &inplace, &w, workspace, queue_alloc)?;
             {
-                // q [1, 768] q1 [1, 3840] kv_pe [1, 288] kv [1, 5120] k [1, 3840] attn [1, 2560]
                 let q1 = tensor(&[nt, nh * dk]);
                 let (buf, workspace) = workspace.split_at_mut(*q1.get());
                 let mut q1 = q1.map(|_| buf);
                 let w = self.weights.attn_qb(iblk, queue).transpose(&[1, 0]);
                 self.mat_mul(&mut q1, 0., &q, &w, 1., workspace, queue_alloc)?;
-                drop(q);
-                // q3 is the data needed to compute attn, but the rotary embedding still has to be applied to part of it
+
                 let mut q3 = q1.tile(1, &[nh, dk]);
                 let q2 = unsafe { q3.map_slice_static_mut() };
-                split_mut!(q2 => _q, q_rope; [dnope, dh] @ 2);
+                split_mut!(q2 => q_nope, q_rope; [dnope, dh] @ 2);
 
                 // kv_pe [1, 288]
                 let kv_pe = tensor(&[nt, dkv_lora + dh]);
@@ -224,65 +227,125 @@ where
 
                 let w = self.weights.attn_kva(iblk, queue).transpose(&[1, 0]);
                 self.mat_mul(&mut kv_pe, 0., &x1, &w, 1., workspace, queue_alloc)?;
-
+                drop(q);
                 split_mut!(kv_pe => kv_lora, k_rope; [dkv_lora, dh] @ 1);
 
+                self.rope(&mut q_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
+                let mut k_rope = k_rope.tile(1, &[1, dh]);
+                self.rope(&mut k_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
+                let k_rope = k_rope.broadcast(1, nh);
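+                // RoPE is applied only to the dh-wide "rope" slices of q and to the single shared
+                // k_rope row, which is then broadcast across all nh heads rather than copied per head.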
+
                 let inplace = unsafe { kv_lora.map_slice_static() };
                 let w = self.weights.attn_kva_norm(iblk, queue);
                 self.rms_norm(&mut kv_lora, &inplace, &w, workspace, queue_alloc)?;
                 // kv [1, 5120]
                 let kv = tensor(&[nt, nh * (dnope + dv)]);
                 let (buf, workspace) = workspace.split_at_mut(*kv.get());
                 let mut kv = kv.map(|_| buf);
-                let w = self.weights.attn_kvb(iblk, queue).transpose(&[1, 0]);
-
-                self.mat_mul(&mut kv, 0., &kv_lora, &w, 1., workspace, queue_alloc)?;
-
-                let kv = kv.tile(1, &[nh, dnope + dv]);
-
-                split_mut!(kv => k_nope, v; [dnope, dv] @ 2);
 
-                // k [1, 3840]
-                let k = tensor(&[nt, nh, dk]);
-                let (buf, workspace) = workspace.split_at_mut(*k.get());
-                let k = k.map(|_| buf);
-
-                split_mut!(k => k_nope_r, k_rope_r; [dnope, dh] @ 2);
-
-                self.rope(&mut q_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
-                let mut k_rope = k_rope.tile(1, &[1, dh]);
-                self.rope(&mut k_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
-                let k_rope = k_rope.broadcast(1, nh);
-                self.rearrange(&mut k_rope_r, &k_rope, workspace, queue_alloc)?;
-                self.rearrange(&mut k_nope_r, &k_nope, workspace, queue_alloc)?;
-
-                let pos = requests.last().unwrap().pos as f32;
-                let mut q = q3.transpose(&[1, 0]);
-                let k = k.map_slice().transpose(&[1, 0]);
-                let v = v.map_slice_mut().transpose(&[1, 0]);
-                // run attention
-                let attn = tensor(&[nt, nh, dv]);
-                let (buf, workspace) = workspace.split_at_mut(*attn.get());
-                let mut attn = attn.map(|_| buf);
-
-                let mut attn = unsafe { attn.map_slice_mut().transpose(&[1, 0]) };
-                let pos = requests.last().unwrap().pos as f32;
-                self.attnention(
-                    &mut q,
-                    &k,
-                    &v,
-                    &mut attn,
-                    pos as usize,
+                let kv_b_proj = unsafe {
+                    self.weights
+                        .attn_kvb(iblk, queue)
+                        .tile(0, &[nh, dnope + dv])
+                        .map_slice_static()
+                };
+                split!(kv_b_proj => q_absorb, out_absorb; [dnope, dv] @ 1);
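+                // Weight absorption (presumably the DeepSeek-V2-style MLA trick): kv_b_proj, the
+                // up-projection from the compressed KV latent, is split into q_absorb
+                // ([nh, dnope, dkv_lora]), folded into the query path below, and out_absorb
+                // ([nh, dv, dkv_lora]), applied after attention, so full per-head K/V are never
+                // materialized.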
+                let inplace = unsafe { q_nope.map_slice_static() };
+
+                let q_nope_0 = q_nope.map_slice().transpose(&[1, 0]);
+                let q_nope_1 = tensor(&[nh, nt, dkv_lora]);
+                let (buf, workspace) = workspace.split_at_mut(*q_nope_1.get());
+                let mut q_nope = q_nope_1.map(|_| buf);
+                self.mat_mul(
+                    &mut q_nope,
+                    0.,
+                    &q_nope_0,
+                    &q_absorb,
+                    1.,
                     workspace,
                     queue_alloc,
                 )?;
 
-                let o = attn.transpose(&[1, 0]).merge(1..3).unwrap();
+                drop(q3);
+                // compute attn_weights
+                // todo: DeepSeek applies a softmax_scale here
+                // reference Python:
+                // attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
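+                // Mapping to the line above: the first mat_mul below computes the positional term
+                // q_pe · k_pe^T, and the second accumulates (beta = 1) the content term
+                // q_nope · compressed_kv^T into the same [nh, nt, nt] buffer; softmax_scale is not
+                // applied yet, hence the todo.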
+                let attn_weights = tensor(&[nh, nt, nt]);
+                let (buf, workspace) = workspace.split_at_mut(*attn_weights.get());
+                let mut attn_weights = attn_weights.map(|_| buf);
+                {
+                    let q_rope = q_rope.transpose(&[1, 0]);
+                    let k_rope = k_rope.transpose(&[1, 2, 0]);
+
+                    self.mat_mul(
+                        &mut attn_weights,
+                        0.,
+                        &q_rope,
+                        &k_rope,
+                        1.,
+                        workspace,
+                        queue_alloc,
+                    )?;
+                    let kv_lora = kv_lora
+                        .map_slice()
+                        .tile(0, &[1, 1])
+                        .broadcast(0, nh)
+                        .transpose(&[2, 1]);
+                    self.mat_mul(
+                        &mut attn_weights,
+                        1.,
+                        &q_nope,
+                        &kv_lora,
+                        1.,
+                        workspace,
+                        queue_alloc,
+                    )?;
+                }
+                // softmax
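+                // The fused softmax helper below normalizes each row of the [nh, nt, nt] scores in
+                // place and applies a causal mask (AttnMask::Causal).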
+                self.softmax(&mut attn_weights, workspace, queue_alloc)?;
+                // attn_output
+                let attn_output_r = tensor(&[nt, nh, dv]);
+                let (buf, workspace) = workspace.split_at_mut(*attn_output_r.get());
+                let mut attn_output_r = attn_output_r.map(|_| buf);
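+                // Second half of the absorption: attn_weights × kv_lora gives the per-head output in
+                // the compressed latent space ([nh, nt, dkv_lora]); out_absorb then maps it back to
+                // the value dimension dv, written into attn_output_r ([nt, nh, dv]).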
+                {
+                    let attn_output = tensor(&[nh, nt, dkv_lora]);
+                    let (buf, workspace) = workspace.split_at_mut(*attn_output.get());
+                    let mut attn_output = attn_output.map(|_| buf);
+                    let kv_lora = kv_lora.tile(0, &[1, 1]).broadcast(0, nh);
+                    self.mat_mul(
+                        &mut attn_output,
+                        0.,
+                        &attn_weights,
+                        &kv_lora,
+                        1.,
+                        workspace,
+                        queue_alloc,
+                    )?;
+                    let mut attn_output_r = attn_output_r.map_slice_mut().transpose(&[1, 0]);
+                    let out_absorb = out_absorb.transpose(&[2, 1]);
+
+                    self.mat_mul(
+                        &mut attn_output_r,
+                        0.,
+                        &attn_output,
+                        &out_absorb,
+                        1.,
+                        workspace,
+                        queue_alloc,
+                    )?;
+                }
+                Ops::debug(&attn_output_r, queue);
+                todo!();
+                // let o = attn_output_r;
                 let w = self.weights.attn_o(iblk, queue);
-
-                self.mat_mul(&mut x1, 0., &o, &w, s, workspace, queue_alloc)?;
-                let inplace = unsafe { x.map_slice_static() };
-                self.add(&mut x, &inplace, &x1, workspace, queue_alloc)?;
+                println!("{:?}", attn_output_r.shape());
+                println!("{:?}", w.shape());
+                // println!("{:?}", out_absorb.shape());
+                todo!();
+                // self.mat_mul(&mut x1, 0., &o, &w, s, workspace, queue_alloc)?;
+                // let inplace = unsafe { x.map_slice_static() };
+                // self.add(&mut x, &inplace, &x1, workspace, queue_alloc)?;
             }
             let w = self.weights.ffn_norm(iblk, queue);
             self.rms_norm(&mut x1, &x, &w, workspace, queue_alloc)?;
@@ -594,6 +657,26 @@ where
             queue_alloc,
         )
     }
+    fn softmax<A, QA>(
+        &self,
+        a: &mut Tensor<A>,
+        workspace: &mut [ByteOf<Ops::Hardware>],
+        queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        A: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+        QA: QueueAlloc<Hardware = Ops::Hardware>,
+    {
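+        // Launch the fused softmax in place over the attention scores: `a` is both input and
+        // output, and a causal mask is applied to each row.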
+        self.fuesd_softmax.launch(
+            &fuesd_softmax::Args {
+                att_mask: AttnMask::Causal,
+                att_layout: a.layout(),
+                att_base: a.base_mut(),
+            },
+            workspace,
+            queue_alloc,
+        )
+    }
     fn all_reduce<X, QA>(
         &self,
         x: &mut Tensor<X>,