@@ -114,7 +114,7 @@ def _apply_padding_and_fill_unused_part_with_nan(t, is_padded):
 # ---------------


-def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, n_expts_tot=1, expt_is_inner=False, device="cuda"):
+def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, mode, n_expts_tot=1, expt_is_inner=False, device="cuda"):
     weight_use_flexpoint = weight_dtype.itemsize == 1 and not weight_mxfp
     # flexpoint
     make_tensor = lambda val0, val1: torch.tensor([val0, val1] * (n_expts_tot // 2) +
@@ -133,8 +133,8 @@ def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, n_ex
         ) if weight_use_flexpoint else InFlexData(),
         out_data=OutFlexData(
             dtype=out_dtype,
-            expected_scale=make(4.00, 5.00, expt_is_inner),
-            actual_scale=make(0, 0, expt_is_inner),
+            expected_scale=make(4.00, 5.00, mode == "batched" or expt_is_inner),
+            actual_scale=make(0, 0, mode == "batched" or expt_is_inner),
             checksum_scale=None,
         ) if act_use_flexpoint else OutFlexData(),
     )
@@ -233,6 +233,7 @@ class Case:
     Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9),
     Case(16, 16, 1000, "batched", "float16", "float16", 5, 1, split_k=None),
     Case(16, 16, 1000, "batched", "float8_e5m2", "float8_e5m2", 5, 1, split_k=None),
+    Case(16, 16, 2048, "batched", "float8_e5m2", "float8_e5m2", 6, 1, split_k=5),
     # mx types:
     Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1),
     Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
@@ -412,7 +413,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
     weight_dtype = dtype_str_to_torch(weight_dtype_str)
     act_dtype = dtype_str_to_torch(act_dtype_str)
     precision_opt = init_precision(act_dtype, act_is_float8, weight_dtype, weight_mxfp,
-                                   n_expts_tot, expt_is_inner, device=device)
+                                   mode, n_expts_tot, expt_is_inner, device=device)
     # precision_opt.x_pad_trans_requires_flexpoint = False
     if mode == "ragged":
         m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, do_gather, do_scatter,
@@ -667,7 +668,7 @@ def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter,
     else:
         rdata = gindx = sindx = None

-    precision_opt = init_precision(act_dtype, str(act_dtype).startswith("torch.float8"), weight_dtype, False, n_expts_tot, device=device)
+    precision_opt = init_precision(act_dtype, str(act_dtype).startswith("torch.float8"), weight_dtype, False, mode, n_expts_tot, device=device)
     x, w, bias, _, _ = init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, mode,
                                          act_dtype, weight_dtype, False, requires_grad=False, device=device)

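Note: with this change, `init_precision` takes `mode` as a required positional argument ahead of `n_expts_tot`, and the flexpoint `expected_scale`/`actual_scale` values are built with `mode == "batched" or expt_is_inner` as the selector passed to `make`. A minimal call-site sketch under the new signature (the dtypes and sizes below are illustrative placeholders, not the test's real parametrization; `init_precision` is the helper updated above):

    import torch

    # Hypothetical values, for illustration only.
    precision_opt = init_precision(
        torch.float16,       # out_dtype
        False,               # act_use_flexpoint
        torch.float8_e5m2,   # weight_dtype
        False,               # weight_mxfp
        "batched",           # mode: new required positional argument
        n_expts_tot=4,
        expt_is_inner=False,
        device="cuda",
    )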