@@ -33,12 +33,10 @@ pub trait Operators {
3333 T : Deref < Target = [ ByteOf < Self :: Hardware > ] > ;
3434
3535 fn memcpy_d2h < T : Copy > (
36- _dst : & mut [ T ] ,
37- _src : & [ ByteOf < Self :: Hardware > ] ,
38- _queue : & QueueOf < Self :: Hardware > ,
39- ) {
40- todo ! ( )
41- }
36+ dst : & mut [ T ] ,
37+ src : & [ ByteOf < Self :: Hardware > ] ,
38+ queue : & QueueOf < Self :: Hardware > ,
39+ ) ;
4240
4341 fn build_sin_cos < QA > (
4442 dt : DigitLayout ,
@@ -81,13 +79,11 @@ pub trait WeightLoader {
8179
8280 fn load_moe < ' a > (
8381 & ' a self ,
84- _which : BlkWeight ,
85- _iblk : usize ,
86- _iexp : usize ,
87- _queue : & ' a QueueOf < Self :: Hardware > ,
88- ) -> Self :: Weight < ' a > {
89- todo ! ( )
90- }
82+ which : BlkWeight ,
83+ iblk : usize ,
84+ iexp : usize ,
85+ queue : & ' a QueueOf < Self :: Hardware > ,
86+ ) -> Self :: Weight < ' a > ;
9187
9288 fn output_norm < ' a > ( & ' a self , queue : & ' a QueueOf < Self :: Hardware > ) -> Self :: Weight < ' a > ;
9389 fn output < ' a > ( & ' a self , queue : & ' a QueueOf < Self :: Hardware > ) -> Self :: Weight < ' a > ;
@@ -105,17 +101,10 @@ pub struct LlamaWorker<Ops: Operators, W> {
105101 swiglu : Ops :: Swiglu ,
106102 rearrange : Ops :: Rearrange ,
107103 all_reduce : Ops :: AllReduce ,
108- residual : bool ,
109104}
110105
111106impl < Ops : Operators , W > LlamaWorker < Ops , W > {
112- pub fn new (
113- id : usize ,
114- node : & Ops :: TopoNode ,
115- meta : LlamaMeta ,
116- weights : W ,
117- residual : bool ,
118- ) -> Self {
107+ pub fn new ( id : usize , node : & Ops :: TopoNode , meta : LlamaMeta , weights : W ) -> Self {
119108 let processor = node. processor ( ) ;
120109 Self {
121110 id,
@@ -128,7 +117,6 @@ impl<Ops: Operators, W> LlamaWorker<Ops, W> {
128117 swiglu : Ops :: Swiglu :: new ( processor) ,
129118 rearrange : Ops :: Rearrange :: new ( processor) ,
130119 all_reduce : Ops :: AllReduce :: new ( node) ,
131- residual,
132120 }
133121 }
134122
@@ -199,7 +187,6 @@ where
199187 di,
200188 ..
201189 } = self . meta ;
202- let residual = if self . residual { 1. } else { 0. } ;
203190
204191 let workspace_size = self . workspace_size ( nt, max_seq_len, max_att_len) ;
205192 let mut workspace = Workspace :: new ( queue_alloc, workspace, workspace_size) ;
@@ -289,6 +276,7 @@ where
289276
290277 let o = q. merge ( 1 ..3 ) . unwrap ( ) ;
291278 let w = self . weights . attn_o ( iblk, queue) ;
279+ let residual = if self . id == 0 { 1. } else { 0. } ;
292280 self . mat_mul ( & mut x, residual, & o, & w, 1. , workspace, queue_alloc) ?
293281 }
294282 self . all_reduce ( & mut x, workspace, queue_alloc) ?;
@@ -310,6 +298,7 @@ where
310298 self . swiglu ( & mut gate, & up, workspace, queue_alloc) ?;
311299
312300 let w = self . weights . ffn_down ( iblk, 0 , queue) ;
301+ let residual = if self . id == 0 { 1. } else { 0. } ;
313302 self . mat_mul ( & mut x, residual, & gate, & w, 1. , workspace, queue_alloc) ?
314303 } else {
315304 let mut routes_host = routes. clone ( ) . map ( Blob :: new) . take ( ) ;
@@ -336,6 +325,7 @@ where
336325 for ( mut x, x1) in izip ! ( x, x1) {
337326 let ( line, tail) = routes. split_at ( nexp) ;
338327 routes = tail;
328+ let mut first = true ;
339329 for ( iexp, kexp) in self . topk_with_index ( line) {
340330 let w = self . weights . ffn_gate_up ( iblk, iexp, queue) ;
341331 self . mat_mul ( & mut gate_up, 0. , & x1, & w, 1. , workspace, queue_alloc) ?;
@@ -346,7 +336,9 @@ where
346336 self . swiglu ( & mut gate, & up, workspace, queue_alloc) ?;
347337
348338 let w = self . weights . ffn_down ( iblk, iexp, queue) ;
349- self . mat_mul ( & mut x, residual, & gate, & w, kexp, workspace, queue_alloc) ?
339+ let residual = if self . id == 0 || !first { 1. } else { 0. } ;
340+ self . mat_mul ( & mut x, residual, & gate, & w, kexp, workspace, queue_alloc) ?;
341+ first = false
350342 }
351343 }
352344 }
0 commit comments