@@ -2,25 +2,28 @@
 use gguf::ggml_quants::digit_layout::{types as ty, DigitLayout};
 use itertools::izip;
 use operators::{
+    all_reduce::{AllReduce, ReduceOp},
     attention_kv_cached::AttnKVCached,
     mat_mul::MatMul,
     mlp::Mlp,
     rearrange::Rearrange,
     rms_norm::RmsNorm,
     rope::{Rope, Seq},
-    ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, Workspace,
+    ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode, Workspace,
 };
 use std::ops::{Deref, DerefMut};
 use tensor::{dt_size, split, Tensor};
 
 pub trait Operators {
     type Hardware: Hardware;
+    type TopoNode: TopoNode<Self::Hardware>;
     type RmsNorm: RmsNorm<Self::Hardware>;
     type MatMul: MatMul<Self::Hardware>;
     type Rope: Rope<Self::Hardware>;
     type AttnKVCached: AttnKVCached<Self::Hardware>;
     type Mlp: Mlp<Self::Hardware>;
     type Rearrange: Rearrange<Self::Hardware>;
+    type AllReduce: AllReduce<Self::Hardware, Self::TopoNode>;
 
     fn debug<T>(tensor: &Tensor<T>)
     where
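This first hunk threads a communication topology through the `Operators` trait: alongside the per-device operator types, an implementation now also names the `TopoNode` it runs on and an `AllReduce` collective bound to both the hardware and that node. A toy reduction of the shape (the trait names mirror the real ones, but these bodies are illustrative stand-ins, not the `operators` crate's actual definitions):

```rust
// Illustrative only: minimal stand-ins for the real `operators` traits,
// showing why `AllReduce` is parameterized over both the hardware and
// the topology node.
trait Hardware {}

trait TopoNode<H: Hardware> {
    // A node in the communication topology exposes the local processor
    // it runs on (see `node.processor()` in the next hunk).
    fn processor(&self) -> &H;
}

trait AllReduce<H: Hardware, N: TopoNode<H>> {
    // A collective is constructed from the whole node, not just the
    // processor: it needs to know its peers, not only its device.
    fn new(node: &N) -> Self;
}
```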
@@ -73,19 +76,21 @@ pub struct LlamaWorker<Ops: Operators, W> {
     attn_kv_cached: Ops::AttnKVCached,
     mlp: Ops::Mlp,
     rearrange: Ops::Rearrange,
+    all_reduce: Ops::AllReduce,
 }
 
 impl<Ops: Operators, W> LlamaWorker<Ops, W> {
-    pub fn new(processor: &Ops::Hardware, meta: LlamaMeta, weights: W) -> Self {
+    pub fn new(node: &Ops::TopoNode, meta: LlamaMeta, weights: W) -> Self {
         Self {
             weights: meta.decorator(weights),
             meta,
-            rms_norm: Ops::RmsNorm::new(processor),
-            mat_mul: Ops::MatMul::new(processor),
-            rope: Ops::Rope::new(processor),
-            attn_kv_cached: Ops::AttnKVCached::new(processor),
-            mlp: Ops::Mlp::new(processor),
-            rearrange: Ops::Rearrange::new(processor),
+            rms_norm: Ops::RmsNorm::new(node.processor()),
+            mat_mul: Ops::MatMul::new(node.processor()),
+            rope: Ops::Rope::new(node.processor()),
+            attn_kv_cached: Ops::AttnKVCached::new(node.processor()),
+            mlp: Ops::Mlp::new(node.processor()),
+            rearrange: Ops::Rearrange::new(node.processor()),
+            all_reduce: Ops::AllReduce::new(node),
         }
     }
 
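The constructor now takes the topology node: each single-device operator is still built from the local processor, reached through `node.processor()`, while the collective is built from the node itself. Continuing the stand-in traits from the previous sketch, a hypothetical single-process backend where the collective degenerates to a no-op (all names here are invented for illustration):

```rust
// Hypothetical single-process backend, continuing the stand-in traits
// from the previous sketch.
struct Cpu;
impl Hardware for Cpu {}

struct SingleNode {
    cpu: Cpu,
}
impl TopoNode<Cpu> for SingleNode {
    fn processor(&self) -> &Cpu {
        &self.cpu
    }
}

// With a group of size one there are no peers holding partial sums,
// so the collective can be a unit struct that does nothing.
struct NoAllReduce;
impl AllReduce<Cpu, SingleNode> for NoAllReduce {
    fn new(_node: &SingleNode) -> Self {
        NoAllReduce
    }
}
```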
@@ -240,7 +245,7 @@ where
             self.mat_mul(&mut x, 1., &x1, &w, 1., workspace, queue_alloc)?;
 
             if distribute > 1 {
-                todo!("all reduce")
+                self.all_reduce(&mut x, workspace, queue_alloc)?;
             }
 
             let w = self.weights.ffn_norm(iblk, queue);
@@ -250,7 +255,7 @@ where
             self.mlp(&mut x, &x1, iblk, mlp_alpha, true, workspace, queue_alloc)?;
 
             if distribute > 1 {
-                todo!("all reduce")
+                self.all_reduce(&mut x, workspace, queue_alloc)?;
             }
         }
 
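Both `todo!("all reduce")` placeholders become real launches. The call sites are the two points in each block where sharded weights leave every rank with only a partial result of `x`: after the attention output projection and after the MLP, guarded by `distribute > 1`. A sum reduction merges the partials so all ranks see the full activation. In plain sequential Rust, the arithmetic the collective performs looks like the stand-in below (for exposition only; the real operator communicates between devices instead of looping over host buffers):

```rust
/// What a sum all-reduce computes across a group: every rank starts with
/// a partial result and ends with the elementwise sum over all ranks.
/// Sequential stand-in for exposition only.
fn all_reduce_sum(ranks: &mut [Vec<f32>]) {
    let len = ranks[0].len();
    let mut total = vec![0.0_f32; len];
    for part in ranks.iter() {
        for (t, p) in total.iter_mut().zip(part) {
            *t += *p;
        }
    }
    // Broadcast the reduced result back so every rank holds the full sum.
    for part in ranks.iter_mut() {
        part.copy_from_slice(&total);
    }
}
```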
@@ -483,6 +488,29 @@ where
             queue_alloc,
         )
     }
+
+    fn all_reduce<X, QA>(
+        &self,
+        x: &mut Tensor<X>,
+        workspace: &mut [ByteOf<Ops::Hardware>],
+        queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        X: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+        QA: QueueAlloc<Hardware = Ops::Hardware>,
+    {
+        self.all_reduce.launch(
+            &operators::all_reduce::Args {
+                dst_layout: x.layout(),
+                dst_base: x.base_mut(),
+                src_layout: x.layout(),
+                src_base: x.base(),
+                op: ReduceOp::Sum,
+            },
+            workspace,
+            queue_alloc,
+        )
+    }
 }
 
 struct WeightDecorator<W> {
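The new private `all_reduce` helper wraps the operator's `launch` in the worker's usual workspace/queue plumbing. Note the aliasing convention in `Args`: both the `dst` and `src` layout/base pairs point at the same tensor `x`, so the reduction is in place, which is why the call sites above can pass `&mut x` and keep using it immediately afterwards. The same arithmetic as before, restated in that in-place flavor (again an illustrative stand-in, not the `operators` crate API):

```rust
// In-place flavor matching the helper's dst == src convention: each rank
// reduces its peers' partials directly into its own buffer.
fn all_reduce_sum_inplace(mine: &mut [f32], peers: &[&[f32]]) {
    for peer in peers {
        for (m, p) in mine.iter_mut().zip(*peer) {
            *m += *p;
        }
    }
}
```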