@@ -45,19 +45,16 @@ def mask_indx(idx, n_expts_act):
     return idx
 
 
-def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter, device="cuda"):
+def init_routing_data(m, n_expts_tot, n_expts_act, do_gather, do_scatter, device="cuda"):
     logits = torch.randn((m, n_expts_tot), dtype=torch.float16, device=device, requires_grad=True)
-    routing_data, gather_idx, scatter_idx = routing(logits, n_expts_act, simulated_ep=n_expt_shards)
+    routing_data, gather_idx, scatter_idx = routing(logits, n_expts_act)
     routing_data.gate_scal = None
     gather_idx = gather_idx if do_gather else None
     scatter_idx = scatter_idx if do_scatter else None
-    # TODO: re-enable
-    # if do_gather and do_scatter and n_expts_act == 1 and n_expt_shards == 1:
-    #     scatter_idx = mask_indx(scatter_idx, n_expts_act)
     return m, routing_data, gather_idx, scatter_idx
 
 
-def init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode, act_dtype, weight_dtype,
+def init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, mode, act_dtype, weight_dtype,
                       has_y_gammas, requires_grad=True, device="cuda",
                       inner_expt_opt=None, padding_block_k=None):
     torch.manual_seed(0)
@@ -70,7 +67,7 @@ def init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, n_
     else:
         in_m = m * (n_expts_act if gindx is None else 1)
     shape_x = (n_expts_tot, in_m, k) if mode == 'batched' else (in_m, k)
-    shape_batch = tuple() if (mode == "plain" or inner_expt_opt is not None) else (n_expts_tot // n_expt_shards, )
+    shape_batch = tuple() if (mode == "plain" or inner_expt_opt is not None) else (n_expts_tot, )
     x = alloc_rand(shape_x, device=device, dtype=act_dtype, requires_grad=requires_grad)
     w = alloc_rand(shape_batch + (k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad)
     bias = alloc_rand(shape_batch + (n, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
@@ -194,7 +191,6 @@ class Case:
     weight_dtype_str: str
     n_expts_tot: int = 1
     n_expts_act: int = 1
-    n_expt_shards: int = 1
     split_k: int = 1
     hbm_swizzling: bool = False
     epilogue_subtile: Union[int, None] = None
@@ -216,10 +212,6 @@ class Case:
     Case(5, 7, 0, "batched", "float16", "float16"),
     # Non-mx types:
     Case(16, 256, 256, "ragged", "float16", "float16", 128, 4),
-    Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=2),
-    Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=4),
-    Case(400, 300, 500, "ragged", "float16", "float16", 32, 4, n_expt_shards=4),
-    Case(16, 256, 256, "ragged", "float16", "float16", 4, 1, n_expt_shards=2),
     Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3),
     Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3),
     Case(300, 400, 400, "batched", "float8_e5m2", "float8_e5m2", 5, 1),
@@ -235,8 +227,6 @@ class Case:
     Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=2),
     Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=4),
     Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2),
-    Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, n_expt_shards=2),
-    Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 1, n_expt_shards=2),
     Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, split_k=2),
     Case(1000, 400, 400, "ragged", "float16", "float16", 3, 1),
     Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2),
@@ -291,19 +281,17 @@ class Case:
     Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"),
     Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1),
     Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2),
-    Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, n_expt_shards=2),
     Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, split_k=2),
     Case(300, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn"),
     Case(1000, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1),
     Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2),
-    Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2, n_expt_shards=2),
 ] + [
-    Case(320, 400, 400, mode, dtype, dtype, n_expts_tot, n_expts_act, n_expt_shards=n_expt_shards,
+    Case(320, 400, 400, mode, dtype, dtype, n_expts_tot, n_expts_act,
          x_transpose=x_transpose, w_transpose=w_transpose, y_transpose=y_transpose)
-    for (mode, n_expts_tot, n_expts_act, n_expt_shards) in (
-        ("batched", 1, 1, 1),
-        ("ragged", 8, 4, 1),
-        ("ragged", 32, 4, 4),
+    for (mode, n_expts_tot, n_expts_act) in (
+        ("batched", 1, 1),
+        ("ragged", 8, 4),
+        ("ragged", 32, 4),
     )
     for dtype in ("float16", "float8_e5m2")
     for x_transpose in (False, True)
@@ -326,7 +314,7 @@ class Case:
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
-            n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
+            n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
             x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope, fresh_knobs):
     # TODO: remove when Triton FP8 supports proper RTNE
@@ -424,17 +412,17 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
     weight_dtype = dtype_str_to_torch(weight_dtype_str)
     act_dtype = dtype_str_to_torch(act_dtype_str)
     precision_opt = init_precision(act_dtype, act_is_float8, weight_dtype, weight_mxfp,
-                                   n_expts_tot // n_expt_shards, expt_is_inner, device=device)
+                                   n_expts_tot, expt_is_inner, device=device)
     # precision_opt.x_pad_trans_requires_flexpoint = False
     if mode == "ragged":
-        m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter,
+        m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, do_gather, do_scatter,
                                                    device=device)
     else:
         rdata = gindx = sindx = None
 
     padding_block_k = 32
     x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act,
-                                                                 n_expt_shards, mode, torch.bfloat16 if act_mxfp8 else act_dtype,  #
+                                                                 mode, torch.bfloat16 if act_mxfp8 else act_dtype,  #
                                                                   torch.bfloat16 if weight_mxfp else weight_dtype,
                                                                   has_y_gammas, requires_grad=test_bwd, device=device,
                                                                   inner_expt_opt=inner_expt_opt, padding_block_k=padding_block_k)
@@ -446,9 +434,9 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
         w_tri = w_tri.detach().transpose(-1, -2).contiguous().transpose(-1, -2).requires_grad_(test_bwd)
     if y_transpose:
         if mode == "batched":
-            yT_shape = (n_expts_tot // n_expt_shards, n, x_tri.shape[-2])
+            yT_shape = (n_expts_tot, n, x_tri.shape[-2])
         elif expt_is_inner:
-            yT_shape = (n_expts_tot // n_expt_shards, n, k)
+            yT_shape = (n_expts_tot, n, k)
         elif sindx is not None:
             yT_shape = (n, m)
         else:
@@ -549,20 +537,6 @@ def scale(val, scal):
         assert val.ndim == 3
         return val / scal[:, None, None]
 
-    if n_expt_shards > 1:
-        if do_scatter:
-            indx = sindx.dst_indx[sindx.dst_indx != -1]
-            ref_y = ref_y[indx // n_expts_act, :]
-            if act_is_float8:
-                tri_y = tri_y.view(torch.int8)
-            tri_y = tri_y[indx // n_expts_act, :]
-            if act_is_float8:
-                tri_y = tri_y.view(act_dtype)
-        elif not expt_is_inner:
-            n_rows = rdata.expt_hist.sum()
-            assert n_rows > 0
-            ref_y = ref_y[:n_rows]
-            tri_y = tri_y[:n_rows]
     if act_mxfp8:
         tri_y = upcast_from_mxfp(tri_y, precision_opt.out_scale, target_dtype=torch.bfloat16, axis=-1).to(ref_y.dtype)
         ref_y_quant, ref_y_scale = downcast_to_mxfp_torch(ref_y, act_dtype, axis=-1)
@@ -683,18 +657,18 @@ def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter,
683657 "split_k" : split_k ,
684658 "fused_scatter" : fused_scatter ,
685659 }
686- n_expts_tot , n_expts_act , n_expt_shards = 1 , 1 , 1
660+ n_expts_tot , n_expts_act = 1 , 1
687661 opt_flags .update_opt_flags_constraints (constraints )
688662
689663 weight_dtype , act_dtype = torch .float16 , torch .float16
690664 if mode == "ragged" :
691- m , rdata , gindx , sindx = init_routing_data (m , n_expts_tot , n_expts_act , n_expt_shards , do_gather , do_scatter ,
665+ m , rdata , gindx , sindx = init_routing_data (m , n_expts_tot , n_expts_act , do_gather , do_scatter ,
692666 device = device )
693667 else :
694668 rdata = gindx = sindx = None
695669
696- precision_opt = init_precision (act_dtype , str (act_dtype ).startswith ("torch.float8" ), weight_dtype , False , n_expts_tot // n_expt_shards , device = device )
697- x , w , bias , _ , _ = init_compute_data (m , n , k , rdata , gindx , sindx , n_expts_tot , n_expts_act , n_expt_shards , mode ,
670+ precision_opt = init_precision (act_dtype , str (act_dtype ).startswith ("torch.float8" ), weight_dtype , False , n_expts_tot , device = device )
671+ x , w , bias , _ , _ = init_compute_data (m , n , k , rdata , gindx , sindx , n_expts_tot , n_expts_act , mode ,
698672 act_dtype , weight_dtype , False , requires_grad = False , device = device )
699673
700674 if mode == "batched" :