11use super :: { args:: Args , LlamaMeta } ;
2- use gguf:: ggml_quants:: digit_layout:: { types as ty, DigitLayout } ;
2+ use gguf:: ggml_quants:: {
3+ digit_layout:: { types as ty, DigitLayout } ,
4+ f16,
5+ } ;
36use itertools:: izip;
47use operators:: {
58 all_reduce:: { self , AllReduce , ReduceOp } ,
@@ -12,7 +15,7 @@ use operators::{
1215 ByteOf , Hardware , LaunchError , Operator , QueueAlloc , QueueOf , TopoNode , Workspace ,
1316} ;
1417use std:: ops:: { Deref , DerefMut } ;
15- use tensor:: { split, Tensor } ;
18+ use tensor:: { split, Blob , Tensor } ;
1619
1720pub trait Operators {
1821 type Hardware : Hardware ;
@@ -29,6 +32,14 @@ pub trait Operators {
2932 where
3033 T : Deref < Target = [ ByteOf < Self :: Hardware > ] > ;
3134
35+ fn memcpy_d2h < T : Copy > (
36+ _dst : & mut [ T ] ,
37+ _src : & [ ByteOf < Self :: Hardware > ] ,
38+ _queue : & QueueOf < Self :: Hardware > ,
39+ ) {
40+ todo ! ( )
41+ }
42+
3243 fn build_sin_cos < QA > (
3344 dt : DigitLayout ,
3445 nctx : usize ,
@@ -68,6 +79,16 @@ pub trait WeightLoader {
6879 queue : & ' a QueueOf < Self :: Hardware > ,
6980 ) -> Self :: Weight < ' a > ;
7081
82+ fn load_moe < ' a > (
83+ & ' a self ,
84+ _which : BlkWeight ,
85+ _iblk : usize ,
86+ _iexp : usize ,
87+ _queue : & ' a QueueOf < Self :: Hardware > ,
88+ ) -> Self :: Weight < ' a > {
89+ todo ! ( )
90+ }
91+
7192 fn output_norm < ' a > ( & ' a self , queue : & ' a QueueOf < Self :: Hardware > ) -> Self :: Weight < ' a > ;
7293 fn output < ' a > ( & ' a self , queue : & ' a QueueOf < Self :: Hardware > ) -> Self :: Weight < ' a > ;
7394}
@@ -118,23 +139,30 @@ impl<Ops: Operators, W> LlamaWorker<Ops, W> {
118139
119140 pub fn workspace_size ( & self , nt : usize , max_seq_len : usize , max_att_len : usize ) -> usize {
120141 let LlamaMeta {
121- dt_mat,
122142 nh,
123143 nkvh,
124- d ,
144+ nexp ,
125145 dh,
126146 di,
127147 ..
128148 } = self . meta ;
129149
130- let ele = dt_mat. nbytes ( ) ;
131- let embd = nt * d * ele;
132- let qkv = nt * ( nh + nkvh + nkvh) * dh * ele;
133- let gate_up = nt * di * 2 * ele;
134- let q = max_seq_len * nh * dh * ele;
135- let att = nkvh * max_seq_len * max_att_len * ele;
136-
137- embd + qkv. max ( gate_up) + q + att
150+ let embd = self . meta . embd ( nt) ;
151+ let dt = embd. dt ( ) ;
152+ let embd = embd. take ( ) ;
153+
154+ let qkv = Tensor :: new ( dt, & [ nt * ( nh + nkvh + nkvh) , dh] ) . take ( ) ;
155+ let q = Tensor :: new ( dt, & [ max_seq_len, nh, dh] ) . take ( ) ;
156+ let att = Tensor :: new ( dt, & [ nkvh, max_seq_len, max_att_len] ) . take ( ) ;
157+
158+ if self . meta . is_moe ( ) {
159+ let routes = Tensor :: new ( dt, & [ nt, nexp] ) . take ( ) ;
160+ let gate_up = Tensor :: new ( dt, & [ 1 , di * 2 ] ) . take ( ) ;
161+ embd + ( qkv + q + att) . max ( routes) . max ( gate_up)
162+ } else {
163+ let gate_up = Tensor :: new ( dt, & [ nt, di * 2 ] ) . take ( ) ;
164+ embd + ( qkv + q + att) . max ( gate_up)
165+ }
138166 }
139167}
140168
@@ -167,6 +195,7 @@ where
167195 nh,
168196 nkvh,
169197 nexp,
198+ nexp_use,
170199 dh,
171200 di,
172201 ..
@@ -182,7 +211,7 @@ where
182211 let mut x1 = x1. map ( |_| buf) ;
183212
184213 let qkv = Tensor :: new ( x. dt ( ) , & [ nt, ( nh + nkvh + nkvh) * dh] ) ;
185- let gate_up = Tensor :: new ( x. dt ( ) , & [ nt , di * 2 ] ) ;
214+ let gate_up = Tensor :: new ( x. dt ( ) , & [ if self . meta . is_moe ( ) { 1 } else { nt } , di * 2 ] ) ;
186215 let routes = Tensor :: new ( x. dt ( ) , & [ nt, nexp] ) ;
187216
188217 let sin = sin_cos. clone ( ) . index ( 0 , 0 ) ;
@@ -264,36 +293,81 @@ where
264293 }
265294 self . all_reduce ( & mut x, workspace, queue_alloc) ?;
266295
267- if !self . meta . is_moe ( ) {
268- let w = self . weights . ffn_norm ( iblk, queue) ;
269- self . rms_norm ( & mut x1, & x, & w, workspace, queue_alloc) ?;
270- drop ( w) ;
296+ let w = self . weights . ffn_norm ( iblk, queue) ;
297+ self . rms_norm ( & mut x1, & x, & w, workspace, queue_alloc) ?;
298+ drop ( w) ;
271299
300+ if !self . meta . is_moe ( ) {
272301 let ( buf, workspace) = workspace. split_at_mut ( * gate_up. get ( ) ) ;
273302 let mut gate_up = gate_up. clone ( ) . map ( |_| buf) ;
274303
275- let w = self . weights . ffn_gate_up ( iblk, queue) ;
304+ let w = self . weights . ffn_gate_up ( iblk, 0 , queue) ;
276305 self . mat_mul ( & mut gate_up, 0. , & x1, & w, 1. , workspace, queue_alloc) ?;
277306 drop ( w) ;
278307
279308 split ! ( gate_up => gate, up; [ di, di] @ 1 ) ;
280309 let mut gate = gate;
281310 self . swiglu ( & mut gate, & up, workspace, queue_alloc) ?;
282311
283- let w = self . weights . ffn_down ( iblk, queue) ;
312+ let w = self . weights . ffn_down ( iblk, 0 , queue) ;
284313 self . mat_mul ( & mut x, residual, & gate, & w, 1. , workspace, queue_alloc) ?
285314 } else {
315+ let mut routes_host = routes. clone ( ) . map ( Blob :: new) . take ( ) ;
316+ // gate_inp
286317 {
287318 let ( buf, workspace) = workspace. split_at_mut ( * routes. get ( ) ) ;
288- let mut routes = routes. clone ( ) . map ( |_| buf) ;
319+ let mut routes_dev = routes. clone ( ) . map ( |_| buf) ;
289320
290321 let w = self . weights . ffn_gate_inp ( iblk, queue) ;
291- self . mat_mul ( & mut routes , 0. , & x , & w, 1. , workspace, queue_alloc) ?;
322+ self . mat_mul ( & mut routes_dev , 0. , & x1 , & w, 1. , workspace, queue_alloc) ?;
292323 drop ( w) ;
293324
294- todo ! ( )
325+ Ops :: memcpy_d2h ( & mut routes_host, routes_dev. get ( ) , queue)
326+ }
327+ let ( [ ] , routes, [ ] ) = ( unsafe { routes_host. align_to_mut :: < f16 > ( ) } ) else {
328+ unreachable ! ( )
329+ } ;
330+
331+ for itok in ( 0 ..nt) . rev ( ) {
332+ // fused topk
333+ let mut routes = routes[ itok * nexp..] [ ..nexp]
334+ . iter ( )
335+ . copied ( )
336+ . enumerate ( )
337+ . collect :: < Vec < _ > > ( ) ;
338+
339+ routes. sort_unstable_by ( |& ( _, a) , & ( _, b) | b. total_cmp ( & a) ) ;
340+ let max = routes[ 0 ] . 1 . to_f32 ( ) ;
341+ let mut sum = 0. ;
342+ let mut moe_gate = vec ! [ ( 0 , 0.0f32 ) ; nexp_use] ;
343+ for ( ( i, x) , gate) in std:: iter:: zip ( routes, & mut moe_gate) {
344+ let softmax = ( x. to_f32 ( ) - max) . exp ( ) ;
345+ * gate = ( i, softmax) ;
346+ sum += softmax
347+ }
348+ for ( _, x) in & mut moe_gate {
349+ * x /= sum
350+ }
351+ // mlp
352+ let ( buf, workspace) = workspace. split_at_mut ( * gate_up. get ( ) ) ;
353+ let mut gate_up = gate_up. clone ( ) . map ( |_| buf) ;
354+
355+ let mut x = x. map_slice_mut ( ) . slice ( 0 , itok, 0 , 1 ) ;
356+ let x1 = x1. map_slice_mut ( ) . slice ( 0 , itok, 0 , 1 ) ;
357+
358+ for ( iexp, kexp) in moe_gate {
359+ let w = self . weights . ffn_gate_up ( iblk, iexp, queue) ;
360+ self . mat_mul ( & mut gate_up, 0. , & x1, & w, 1. , workspace, queue_alloc) ?;
361+ drop ( w) ;
362+
363+ split ! ( gate_up => gate, up; [ di, di] @ 1 ) ;
364+ let mut gate = gate;
365+ self . swiglu ( & mut gate, & up, workspace, queue_alloc) ?;
366+
367+ let w = self . weights . ffn_down ( iblk, iexp, queue) ;
368+ self . mat_mul ( & mut x, residual, & gate, & w, kexp, workspace, queue_alloc) ?
369+ }
295370 }
296- // TODO MLP
297371 }
298372 self . all_reduce ( & mut x, workspace, queue_alloc) ?
299373 }
@@ -553,6 +627,7 @@ struct WeightDecorator<W> {
553627 ffn_gate_up : Tensor < usize > ,
554628 ffn_down : Tensor < usize > ,
555629 output : Tensor < usize > ,
630+ is_moe : bool ,
556631 weights : W ,
557632}
558633
@@ -567,6 +642,7 @@ impl LlamaMeta {
567642 ffn_gate_up : self . ffn_gate_up ( Computation ) ,
568643 ffn_down : self . ffn_down ( Computation ) ,
569644 output : self . output ( ) ,
645+ is_moe : self . is_moe ( ) ,
570646 weights,
571647 }
572648 }
@@ -627,19 +703,31 @@ impl<W: WeightLoader> WeightDecorator<W> {
627703 pub fn ffn_gate_up < ' a > (
628704 & ' a self ,
629705 iblk : usize ,
706+ iexp : usize ,
630707 queue : & ' a QueueOf < W :: Hardware > ,
631708 ) -> Tensor < W :: Weight < ' a > > {
632- let w = self . weights . load_blk ( BlkWeight :: FfnGateUp , iblk, queue) ;
709+ const WHICH : BlkWeight = BlkWeight :: FfnGateUp ;
710+ let w = if self . is_moe {
711+ self . weights . load_moe ( WHICH , iblk, iexp, queue)
712+ } else {
713+ self . weights . load_blk ( WHICH , iblk, queue)
714+ } ;
633715 self . ffn_gate_up . clone ( ) . map ( |_| w)
634716 }
635717
636718 #[ inline]
637719 pub fn ffn_down < ' a > (
638720 & ' a self ,
639721 iblk : usize ,
722+ iexp : usize ,
640723 queue : & ' a QueueOf < W :: Hardware > ,
641724 ) -> Tensor < W :: Weight < ' a > > {
642- let w = self . weights . load_blk ( BlkWeight :: FfnDown , iblk, queue) ;
725+ const WHICH : BlkWeight = BlkWeight :: FfnDown ;
726+ let w = if self . is_moe {
727+ self . weights . load_moe ( WHICH , iblk, iexp, queue)
728+ } else {
729+ self . weights . load_blk ( WHICH , iblk, queue)
730+ } ;
643731 self . ffn_down . clone ( ) . map ( |_| w)
644732 }
645733
0 commit comments