@@ -159,6 +159,9 @@ class Case:
     split_k: int = 1
     hbm_swizzling: bool = False
     epilogue_subtile: Union[int, None] = None
+    x_transpose: bool = False
+    w_transpose: bool = False
+    y_transpose: bool = False


 @pytest.mark.parametrize(
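The three new Case fields all default to False, so every pre-existing Case(...) call site keeps its current behavior; only cases that opt in exercise the transposed layouts. A minimal sketch of why the defaults are backward compatible (CaseSketch is a hypothetical stand-in; the real Case dataclass lives in the test file):

from dataclasses import dataclass

@dataclass
class CaseSketch:
    # Same defaults as the fields added in the hunk above.
    x_transpose: bool = False
    w_transpose: bool = False
    y_transpose: bool = False

# Omitting the new fields is equivalent to passing False explicitly,
# so existing parametrizations are unchanged.
assert CaseSketch() == CaseSketch(False, False, False)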
@@ -252,6 +255,13 @@ class Case:
             Case(1000, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1),
             Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2),
             Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2, n_expt_shards=2),
+        ] + [
+            Case(320, 400, 400, mode, dtype, dtype, x_transpose=x_transpose, w_transpose=w_transpose, y_transpose=y_transpose)
+            for mode in ("batched", "ragged")
+            for dtype in ("float16", "float8_e5m2")
+            for x_transpose in (False, True)
+            for w_transpose in (False, True)
+            for y_transpose in (False, True)
         ]
     ],
 )
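The comprehension enumerates the full cross-product of mode, dtype, and the three transpose flags: 2 x 2 x 2 x 2 x 2 = 32 new cases. An equivalent standalone enumeration, just to make the count explicit (itertools is not used by the diff itself):

import itertools

# Cross-product mirroring the nested for-clauses in the hunk above.
combos = list(itertools.product(
    ("batched", "ragged"),       # mode
    ("float16", "float8_e5m2"),  # dtype, used for both activations and weights
    (False, True),               # x_transpose
    (False, True),               # w_transpose
    (False, True),               # y_transpose
))
assert len(combos) == 32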
@@ -268,6 +278,7 @@ class Case:
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot,
             n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
+            x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope, fresh_knobs):
     # TODO: remove when Triton FP8 supports proper RTNE
     if is_cuda():
@@ -373,6 +384,17 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
                                                      has_y_gammas, requires_grad=test_bwd, device=device)
     x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)

+    if x_transpose:
+        x_tri = x_tri.detach().transpose(-1, -2).contiguous().transpose(-1, -2).requires_grad_(test_bwd)
+    if w_transpose:
+        w_tri = w_tri.detach().transpose(-1, -2).contiguous().transpose(-1, -2).requires_grad_(test_bwd)
+    if y_transpose:
+        n_rows = m if gindx is None else gindx.dst_indx.shape[0]
+        yT_shape = (n_expts_tot, n, n_rows) if mode == "batched" else (n, n_rows)
+        y_tri_in = torch.empty(yT_shape, dtype=act_dtype, device=device).transpose(-1, -2)
+    else:
+        y_tri_in = None
+
     if w_tri.shape[0] == 1 and mode != "batched":
         # Test the case when weight has dim 2, i.e., shape (K, N).
         w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
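The detach/transpose/contiguous/transpose chain is a layout trick, not a data change: the tensor keeps its logical shape and values, but its storage becomes column-major, so the kernel is exercised with transposed strides. For the output, allocating torch.empty in the transposed shape and viewing it back gives the same layout without any copy. A minimal standalone illustration (not part of the diff):

import torch

x = torch.randn(4, 8)
# transpose -> contiguous -> transpose: same logical tensor, column-major storage.
x_colmajor = x.transpose(-1, -2).contiguous().transpose(-1, -2)
assert x_colmajor.shape == x.shape        # still (4, 8)
assert x_colmajor.stride() == (1, 4)      # column-major strides
assert torch.equal(x_colmajor, x)         # values untouched

# Same layout for a fresh output buffer, with no copy at all:
y = torch.empty(8, 4).transpose(-1, -2)   # shape (4, 8), strides (1, 4)
assert y.shape == (4, 8) and y.stride() == (1, 4)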
@@ -423,9 +445,14 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas

     # triton
     try:
-        tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref, epilogue=epilogue)
+        tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt,
+                           gammas=gs1_ref, epilogue=epilogue, y=y_tri_in)
     except (opt_flags.InapplicableConstraint, NotImplementedError):
         pytest.xfail("inapplicable opt_flags constraint")
+    if y_tri_in is not None:
+        assert tri_y.data_ptr() == y_tri_in.data_ptr()
+        assert tri_y.shape == y_tri_in.shape
+        assert tri_y.stride() == y_tri_in.stride()
     # If split_k > 1, then the intermediate tensor is fp32.
     sep_gather = mode == "ragged" and do_gather and n_expts_act > 1 and split_k == 1
     sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1
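The three assertions pin down the preallocated-output contract: when a caller-supplied y is passed, matmul_ogs is expected to write into that exact buffer and return it (same storage, shape, and strides) rather than allocating a fresh tensor. A loose analogy using PyTorch's own out= convention, assuming the y= parameter behaves similarly (this sketch is not from the diff):

import torch

a, b = torch.randn(4, 3), torch.randn(3, 8)
y = torch.empty(8, 4).transpose(-1, -2)  # (4, 8) view over column-major storage
out = torch.mm(a, b, out=y)
assert out.data_ptr() == y.data_ptr()    # result aliases the provided buffer
assert out.shape == y.shape
assert out.stride() == y.stride()        # layout preserved, not re-allocated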
@@ -537,7 +564,7 @@ def test_set_idle_sms():
     num_idle_sms = 24
     matmul_ogs_set_idle_sms(num_idle_sms)
     flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), \
-                           1, 1024, 1024, 1024, None, True, False, 1)
+                           1, 1024, 1024, 1024, None, True, False, 1, False)
     assert flags.idle_sms == num_idle_sms
