@@ -23,6 +23,7 @@ pub trait Operators {
2323 type AllReduce : AllReduce < Self :: Hardware , Self :: TopoNode > ;
2424 type AddRows : AddRows < Self :: Hardware > ;
2525 type Mlp : Gpt2Mlp < Self :: Hardware > ;
26+
2627 fn debug < T > ( tensor : & Tensor < T > )
2728 where
2829 T : Deref < Target = [ ByteOf < Self :: Hardware > ] > ;
@@ -66,6 +67,7 @@ pub struct Gpt2Worker<Ops: Operators, W> {
6667 all_reduce : Ops :: AllReduce ,
6768 add_rows : Ops :: AddRows ,
6869 mlp : Ops :: Mlp ,
70+ pub debug : bool ,
6971}
7072
7173impl < Ops : Operators , W > Gpt2Worker < Ops , W > {
@@ -81,6 +83,7 @@ impl<Ops: Operators, W> Gpt2Worker<Ops, W> {
8183 all_reduce : Ops :: AllReduce :: new ( node) ,
8284 add_rows : Ops :: AddRows :: new ( processor) ,
8385 mlp : Ops :: Mlp :: new ( processor) ,
86+ debug : true ,
8487 }
8588 }
8689
@@ -136,7 +139,6 @@ where
136139 idx,
137140 idx_add,
138141 } = args;
139-
140142 let Gpt2Meta {
141143 dt_embd,
142144 nblk,
@@ -145,6 +147,7 @@ where
145147 dh,
146148 ..
147149 } = self . meta ;
150+
148151 let workspace_size = self . workspace_size ( nt, max_seq_len, max_att_len) ;
149152 let mut workspace = Workspace :: new ( queue_alloc, workspace, workspace_size) ;
150153 let queue = queue_alloc. queue ( ) ;
@@ -161,7 +164,7 @@ where
161164 token_embd = token_embd. merge ( 0 ..2 ) . unwrap ( ) ;
162165 }
163166 let mut x = token_embd;
164- let x1 = Tensor :: new ( dt_embd , x. shape ( ) ) ;
167+ let x1 = Tensor :: new ( x . dt ( ) , x. shape ( ) ) ;
165168 let ( buf, workspace) = workspace. split_at_mut ( * x1. get ( ) ) ;
166169 let mut x1 = x1. map ( |_| buf) ;
167170 let qkv = Tensor :: new ( dt_embd, & [ nt, ( nh + nkvh + nkvh) * dh] ) ;
@@ -177,10 +180,9 @@ where
177180 let mut qkv = qkv. clone ( ) . map ( |_| buf) ;
178181 {
179182 let [ scale, bias] = self . weights . attn_qkv ( iblk, queue) ;
180- let cols = bias. shape ( ) [ 0 ] ;
181- let bias = bias. tile ( 0 , & [ 1 , cols] ) . broadcast ( 0 , nt) ;
183+ let bias = bias. broadcast ( 0 , nt) ;
182184 self . rearrange ( & mut qkv, & bias, workspace, queue_alloc) ?;
183- self . mat_mul ( & mut qkv, 1. , & x1, & scale, 1. , workspace, queue_alloc) ?;
185+ self . mat_mul ( & mut qkv, 1. , & x1, & scale, 1. , workspace, queue_alloc) ?
184186 }
185187 let qkv = qkv. tile ( 1 , & [ nh + nkvh + nkvh, dh] ) ;
186188 split ! ( qkv => q, k, v; [ nh, nkvh, nkvh] @ 1 ) ;
@@ -215,14 +217,13 @@ where
215217 req. pos ,
216218 workspace,
217219 queue_alloc,
218- ) ?;
220+ ) ?
219221 }
220222 }
221223 {
222224 let o = q. map_slice ( ) . merge ( 1 ..3 ) . unwrap ( ) ;
223225 let [ scale, bias] = self . weights . attn_o ( iblk, queue) ;
224- let cols = bias. shape ( ) [ 0 ] ;
225- let bias = bias. tile ( 0 , & [ 1 , cols] ) . broadcast ( 0 , nt) ;
226+ let bias = bias. broadcast ( 0 , nt) ;
226227 self . rearrange ( & mut x1, & bias, workspace, queue_alloc) ?;
227228 self . mat_mul ( & mut x1, 1. , & o, & scale, 1. , workspace, queue_alloc) ?;
228229 }
@@ -506,50 +507,40 @@ where
506507}
507508
// NOTE(review): this span is a rendered diff, not plain source. A number prefix
// ending in `-` appears to mark a line removed by the change and `+` a line
// added by it; lines carrying two fused numbers are unchanged context — confirm
// against the real file before relying on this.
//
// `WeightDecorator` pairs a weight source (`weights : W`) with a set of
// `Tensor < usize >` fields. The accessor methods visible elsewhere in this diff
// clone one of these fields and `map` a loaded weight into it
// (e.g. `self . norm . clone ( ) . map ( |_| w)`).
508509struct WeightDecorator < W > {
509- attn_norm_w : Tensor < usize > ,
510- attn_norm_b : Tensor < usize > ,
510+ pos_embd : Tensor < usize > ,
511+ output_weight : Tensor < usize > ,
// Single field used by the new bodies of the attn-norm, ffn-norm and output-norm
// accessors for both the `w` and the `b` tensor; it stands in for the six
// per-norm `*_norm_w` / `*_norm_b` fields marked as removed in this struct.
512+ norm : Tensor < usize > ,
513+
511514 attn_qkv_w : Tensor < usize > ,
512515 attn_qkv_b : Tensor < usize > ,
513516 attn_o_w : Tensor < usize > ,
514517 attn_o_b : Tensor < usize > ,
515518
516- ffn_norm_w : Tensor < usize > ,
517- ffn_norm_b : Tensor < usize > ,
518519 ffn_up_w : Tensor < usize > ,
519520 ffn_up_b : Tensor < usize > ,
520521 ffn_down_w : Tensor < usize > ,
521522 ffn_down_b : Tensor < usize > ,
522523
523- output_norm_w : Tensor < usize > ,
524- output_norm_b : Tensor < usize > ,
525- output_weight : Tensor < usize > ,
526- pos_embd : Tensor < usize > ,
527-
// The wrapped weight source; the accessors call `load_blk` / `output_norm` on it.
528524 weights : W ,
529525}
530526
531527impl Gpt2Meta {
532528 fn decorator < W > ( & self , weights : W ) -> WeightDecorator < W > {
533529 use crate :: TensorUsage :: Computation ;
534530 WeightDecorator {
535- attn_norm_w : self . attn_norm_w ( ) ,
536- attn_norm_b : self . attn_norm_b ( ) ,
531+ pos_embd : self . pos_embd ( ) ,
532+ output_weight : self . output_weight ( ) ,
533+ norm : self . norm ( ) ,
534+
537535 attn_qkv_w : self . attn_qkv_w ( Computation ) ,
538- attn_qkv_b : self . attn_qkv_b ( ) ,
536+ attn_qkv_b : self . attn_qkv_b ( Computation ) ,
539537 attn_o_w : self . attn_o_w ( Computation ) ,
540- attn_o_b : self . attn_o_b ( ) ,
538+ attn_o_b : self . attn_o_b ( Computation ) ,
541539
542- ffn_norm_w : self . ffn_norm_w ( ) ,
543- ffn_norm_b : self . ffn_norm_b ( ) ,
544540 ffn_up_w : self . ffn_up_w ( Computation ) ,
545- ffn_up_b : self . ffn_up_b ( ) ,
541+ ffn_up_b : self . ffn_up_b ( Computation ) ,
546542 ffn_down_w : self . ffn_down_w ( Computation ) ,
547- ffn_down_b : self . ffn_down_b ( ) ,
548-
549- output_norm_w : self . output_norm_w ( ) ,
550- output_norm_b : self . output_norm_b ( ) ,
551- output_weight : self . output_weight ( ) ,
552- pos_embd : self . pos_embd ( ) ,
543+ ffn_down_b : self . ffn_down_b ( Computation ) ,
553544
554545 weights,
555546 }
@@ -563,10 +554,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
563554 queue : & QueueOf < W :: Hardware > ,
564555 ) -> [ Tensor < W :: Memory < ' _ > > ; 2 ] {
565556 let [ w, b] = self . weights . load_blk ( BlkWeight :: AttnNorm , iblk, queue) ;
566- [
567- self . attn_norm_w . clone ( ) . map ( |_| w) ,
568- self . attn_norm_b . clone ( ) . map ( |_| b) ,
569- ]
557+ [ self . norm . clone ( ) . map ( |_| w) , self . norm . clone ( ) . map ( |_| b) ]
570558 }
571559
572560 pub fn attn_qkv (
@@ -595,10 +583,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
595583 queue : & QueueOf < W :: Hardware > ,
596584 ) -> [ Tensor < W :: Memory < ' _ > > ; 2 ] {
597585 let [ w, b] = self . weights . load_blk ( BlkWeight :: FfnNorm , iblk, queue) ;
598- [
599- self . ffn_norm_w . clone ( ) . map ( |_| w) ,
600- self . ffn_norm_b . clone ( ) . map ( |_| b) ,
601- ]
586+ [ self . norm . clone ( ) . map ( |_| w) , self . norm . clone ( ) . map ( |_| b) ]
602587 }
603588
604589 pub fn ffn_up ( & self , iblk : usize , queue : & QueueOf < W :: Hardware > ) -> [ Tensor < W :: Memory < ' _ > > ; 2 ] {
@@ -623,10 +608,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
623608
// Fetches the pair `[ w, b]` from `self . weights . output_norm ( queue)` and
// returns two tensors, each built by cloning a stored tensor and mapping the
// loaded value into it.
// NOTE(review): this span is a rendered diff — the `-` prefixed lines look like
// the old body (separate `output_norm_w` / `output_norm_b` fields) and the `+`
// line the new body, where both results are built from the same `self . norm`.
// Presumably `w` and `b` are expected to share that tensor's layout — confirm.
624609 pub fn output_norm ( & self , queue : & QueueOf < W :: Hardware > ) -> [ Tensor < W :: Memory < ' _ > > ; 2 ] {
625610 let [ w, b] = self . weights . output_norm ( queue) ;
626- [
627- self . output_norm_w . clone ( ) . map ( |_| w) ,
628- self . output_norm_b . clone ( ) . map ( |_| b) ,
629- ]
611+ [ self . norm . clone ( ) . map ( |_| w) , self . norm . clone ( ) . map ( |_| b) ]
630612 }
631613
632614 pub fn output_weight ( & self , queue : & QueueOf < W :: Hardware > ) -> Tensor < W :: Memory < ' _ > > {
0 commit comments