@@ -39,9 +39,8 @@ def mask_indx(idx, n_expts_act):
     return idx


-def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter):
-    dev = "cuda"
-    logits = torch.randn((m, n_expts_tot), dtype=torch.float16, device=dev, requires_grad=True)
+def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter, device="cuda"):
+    logits = torch.randn((m, n_expts_tot), dtype=torch.float16, device=device, requires_grad=True)
     routing_data, gather_idx, scatter_idx = routing(logits, n_expts_act, simulated_ep=n_expt_shards)
     routing_data.gate_scal = None
     gather_idx = gather_idx if do_gather else None
@@ -50,17 +49,18 @@ def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_


 def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode, act_dtype, weight_dtype,
-                      has_y_gammas, requires_grad=True, dev="cuda"):
+                      has_y_gammas, requires_grad=True, device="cuda"):
     torch.manual_seed(0)
     assert mode in {'batched', 'ragged'}
     in_m = m * (n_expts_act if gindx is None else 1)
     out_m = m * (n_expts_act if sindx is None else 1)
     shape_x = (n_expts_tot, in_m, k) if mode == 'batched' else (in_m, k)
-    x = alloc_rand(shape_x, device=dev, dtype=act_dtype, requires_grad=requires_grad)
-    w = alloc_rand((n_expts_tot // n_expt_shards, k, n), device=dev, dtype=weight_dtype, requires_grad=requires_grad)
-    bias = alloc_rand((n_expts_tot // n_expt_shards, n), device=dev, dtype=torch.float32, requires_grad=requires_grad)
-    gs0 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=dev, dtype=torch.float32, requires_grad=requires_grad)
-    gs1 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=dev, dtype=torch.float32, requires_grad=requires_grad)
+    x = alloc_rand(shape_x, device=device, dtype=act_dtype, requires_grad=requires_grad)
+    w = alloc_rand((n_expts_tot // n_expt_shards, k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad)
+    bias = alloc_rand((n_expts_tot // n_expt_shards, n), device=device, dtype=torch.float32,
+                      requires_grad=requires_grad)
+    gs0 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
+    gs1 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
     gs0 = gs0.detach().requires_grad_(requires_grad)
     gs1 = gs1.detach().requires_grad_(requires_grad)
     if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2):
@@ -75,12 +75,13 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
 # ---------------


-def init_precision(out_dtype, act_use_flexpoint, weight_use_flexpoint, n_expts_tot=1, mx_ctx=MicroscalingCtx()):
+def init_precision(out_dtype, act_use_flexpoint, weight_use_flexpoint, n_expts_tot=1, mx_ctx=MicroscalingCtx(),
+                   device="cuda"):
     # flexpoint
     make_tensor = lambda val0, val1: torch.tensor([val0, val1] * (n_expts_tot // 2) +
                                                   ([val0]
-                                                   if n_expts_tot % 2 else []), dtype=torch.float32, device="cuda")
-    make_scalar = lambda val: torch.tensor([val], dtype=torch.float32, device="cuda")
+                                                   if n_expts_tot % 2 else []), dtype=torch.float32, device=device)
+    make_scalar = lambda val: torch.tensor([val], dtype=torch.float32, device=device)
     in_flex_data = lambda scale, use_flex: InFlexData(dtype=torch.float8_e5m2, scale=make_scalar(scale)
                                                       ) if use_flex else InFlexData()
     in_flex_edata = lambda scale0, scale1, use_flex: InFlexData(dtype=torch.float8_e5m2, scale=make_tensor(
@@ -211,7 +212,7 @@ class Case:
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot,
-            n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, swizzle_mx_scale):
+            n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, swizzle_mx_scale, device):
     # TODO: remove when Triton FP8 supports proper RTNE
     if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9:
         pytest.skip("Float8 not tested on A100")
@@ -254,16 +255,17 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     act_is_float8 = act_dtype.itemsize == 1
     weight_is_float8 = weight_dtype.itemsize == 1
     precision_opt = init_precision(act_dtype, act_is_float8, weight_is_float8 and not is_mixed_input,
-                                   n_expts_tot // n_expt_shards)
+                                   n_expts_tot // n_expt_shards, device=device)
     # precision_opt.x_pad_trans_requires_flexpoint = False
     if mode == "ragged":
-        m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter)
+        m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter,
+                                                   device=device)
     else:
         rdata = gindx = sindx = None
     x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act,
                                                                  n_expt_shards, mode, act_dtype,  #
                                                                  torch.bfloat16 if is_mixed_input else weight_dtype,
-                                                                 has_y_gammas, requires_grad=test_bwd)
+                                                                 has_y_gammas, requires_grad=test_bwd, device=device)
     x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)

     if is_mixed_input:
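
The change threads a `device` argument through `init_precision`, `init_routing_data`, `init_compute_data`, and `test_op`; pytest would normally inject that argument into `test_op` via a fixture, which this diff does not show. Below is a minimal sketch of one way such a fixture could be defined in a `conftest.py`, assuming a hypothetical `--device` command-line option; the actual project may provide the fixture differently.

# Hypothetical conftest.py sketch (not part of this diff): one possible source
# of the `device` argument that test_op and the init_* helpers now accept.
import pytest
import torch


def pytest_addoption(parser):
    # Assumed CLI flag; defaults to the same "cuda" device the helpers default to.
    parser.addoption("--device", action="store", default="cuda",
                     help="torch device to run tests on, e.g. 'cuda', 'cuda:1'")


@pytest.fixture
def device(request):
    dev = request.config.getoption("--device")
    # Skip rather than fail when a CUDA device is requested but unavailable.
    if dev.startswith("cuda") and not torch.cuda.is_available():
        pytest.skip("CUDA device requested but not available")
    return dev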