# testing utilities
from triton_kernels.testing import assert_close, compute_actual_scale
# target-specific utilities
-from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4
+from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_xpu, is_hip_cdna4

# ---------------
# initialize data
@@ -73,7 +73,7 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
    if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2):
        gs0 = None
        gs1 = None
-    if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
+    if is_cuda() and "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
        w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
    return x, w, bias, gs0, gs1

@@ -294,6 +294,10 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
        if split_k > 1:
            pytest.skip("splitK hasn't been fully tested on AMD GPU.")

+    elif is_xpu():
+        if split_k > 1:
+            pytest.skip("FIXME: https://github.com/intel/intel-xpu-backend-for-triton/issues/5074")
+
    if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3():
        pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform")

@@ -308,20 +312,21 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
                pytest.skip("Non-scale swizzling not supported on CDNA4 yet")
            if n % 32 != 0 or k % (32 * 8) != 0:
                pytest.skip(f"Shape {m}x{n}x{k} is not supported for scale swizzling on AMD GPU")
-        if torch.cuda.get_device_capability()[0] < 9:
-            pytest.skip("NYI. Ampere swizzling.")
-        if torch.cuda.get_device_capability()[0] < 10:
-            if "mxfloat4" not in weight_dtype_str:
-                pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
-            if k % 64 != 0 or n % 64 != 0:
-                # Automatic padding not implemented for Hopper swizzle
-                pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")
+        if is_cuda():
+            if torch.cuda.get_device_capability()[0] < 9:
+                pytest.skip("NYI. Ampere swizzling.")
+            if torch.cuda.get_device_capability()[0] < 10:
+                if "mxfloat4" not in weight_dtype_str:
+                    pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
+                if k % 64 != 0 or n % 64 != 0:
+                    # Automatic padding not implemented for Hopper swizzle
+                    pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")

    # launch metadata for batched / mx types may not work yet.
    torch.manual_seed(0)

    block_k = None
-    if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
+    if is_cuda() and is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
        # Override block_k for testing correctness. The default is temporarily 128 for
        # performance reasons which doesn't work with persistent matmul.
        # TODO: revisit when Triton is better for H100 + MXFP4
@@ -436,7 +441,7 @@ def round_x(x, idx):

    round_y = lambda y: (y / y_scale).to(act_dtype).to(torch.float32) * y_scale if sep_scatter else y
    ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref,  #
-                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref)
+                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref, device=device)
    scale = lambda val, scal: val if scal is None else val / scal
    if n_expt_shards > 1:
        if do_scatter:
@@ -549,21 +554,21 @@ def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter,
    (4096, 4096, 0),
])
@pytest.mark.parametrize("view_x_as_zero_cols", [False, True])
-def test_zero_reduction_dim(m, n, k, view_x_as_zero_cols):
+def test_zero_reduction_dim(m, n, k, view_x_as_zero_cols, device):
    torch.manual_seed(0)

    if view_x_as_zero_cols:
-        x = torch.randn(m, m, device="cuda", dtype=torch.bfloat16)
+        x = torch.randn(m, m, device=device, dtype=torch.bfloat16)
        x = x[:0, :].transpose(-1, -2)
    else:
-        x = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
-    w = torch.randn(k, n, device="cuda", dtype=torch.bfloat16)
-    bias = torch.randn(n, device="cuda", dtype=torch.float32)
+        x = torch.randn(m, k, device=device, dtype=torch.bfloat16)
+    w = torch.randn(k, n, device=device, dtype=torch.bfloat16)
+    bias = torch.randn(n, device=device, dtype=torch.float32)

    try:
        tri_y = matmul_ogs(x, w, bias)
    except opt_flags.InapplicableConstraint:
        pytest.skip("inapplicable constraint")
-    ref_y = matmul_ogs_torch(x, w, bias, round_x=lambda x, idx: x, round_y=lambda y: y)
+    ref_y = matmul_ogs_torch(x, w, bias, round_x=lambda x, idx: x, round_y=lambda y: y, device=device)

    assert_close(ref_y, tri_y)
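The hunks above replace hard-coded `"cuda"` strings with a `device` argument that pytest injects into each test. For reference, here is a minimal sketch of what such a fixture could look like; the fixture name matches the new `device` parameter, but the selection logic is an assumption, since the actual fixture lives in the repository's conftest.py, which this diff does not show.

```python
# Hypothetical sketch of a `device` fixture (assumption: the real one is defined in conftest.py).
import pytest
import torch


@pytest.fixture
def device():
    # Prefer an Intel XPU when the backend is present, then CUDA, then CPU.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
```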