@@ -69,6 +69,7 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
6969 x_dtype = torch .float8_e4m3fnuz
7070
7171 input_x = torch .randn ((batch // DP , dim1 ), device = dev )
72+ expt_assignment = triton_dist .create_expt_assignment (EP , n_expts_tot , torch .device (dev ))
7273 # run layer
7374 fpath = Path (tempfile .mktemp ())
7475 proton .start (str (fpath ), hook = "triton" )
@@ -78,15 +79,15 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
7879 if n_expts_tot > 1 : # sparse
7980 logits = matmul_ogs (xg , wg , bg , precision_config = pcg )
8081 x , rdata , gather_indx , scatter_indx , metadata = triton_dist .routing (input_x , logits , n_expts_act , EP = EP ,
81- TP = TP )
82+ TP = TP , expt_assignment = expt_assignment )
8283 else : # dense
8384 x = triton_dist .all_gather (input_x , dim = 0 )
8485 rdata , gather_indx , scatter_indx , metadata = None , None , None , None
8586 if x .nelement () > 0 :
8687 x = matmul_ogs (x , w1 , b1 , rdata , gather_indx = gather_indx , precision_config = pc1 , fused_activation = act )
8788 x = matmul_ogs (x , w2 , b2 if rank % TP == 0 else None , rdata , scatter_indx = scatter_indx ,
8889 precision_config = pc2 )
89- x = triton_dist .reduce_scatter (x , metadata = metadata , dim = 0 )
90+ x = triton_dist .reduce_scatter (x , n_expts_act , metadata = metadata , expt_assignment = expt_assignment )
9091 proton .finalize ()
9192 return roofline .parse_profile (fpath .with_suffix (".hatchet" ), useful_op_regex = ".*matmul.*" )
9293
@@ -136,6 +137,8 @@ def roofline_mlp(batch_sizes, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
136137 parser .add_argument ("--name" , type = str , choices = ["dense" , "gpt-oss-x2" ])
137138 parser .add_argument ("--quantized" , action = "store_true" , default = False )
138139 args = parser .parse_args ()
140+ if args .tp > 1 :
141+ raise NotImplementedError ("TP>1 is not supported yet in distributed mode." )
139142 dtypes = quantized_dtypes if args .quantized else dense_dtypes
140143 if args .name == "dense" :
141144 assert args .ep == 1 , "EP must be 1 for dense"
0 commit comments