@@ -1217,7 +1217,15 @@ def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TR
     if (A_DATA_TYPE == 'float4' and not WITH_A_SCALE) or (B_DATA_TYPE == 'float4' and not WITH_B_SCALE):
         pytest.skip("Float4 without scale is tested in test_block_scale_fp4")
     elif is_xpu():
-        pytest.xfail("XPU does not natively support scaled mxfp8 & mxfp4 matmul")
+        if not (WITH_A_SCALE and WITH_B_SCALE):
+            pytest.skip("None scale has not been tested on XPU backend")
+        if not (A_DATA_TYPE == "float8e5" and B_DATA_TYPE == "float4"):
+            pytest.skip(f"(A: {A_DATA_TYPE}, B: {B_DATA_TYPE}) has not been tested on XPU backend")
+        if (BLOCK_M, BLOCK_N,
+                BLOCK_K) == (128, 256,
+                             256) and CONST_SCALE and triton.runtime.driver.active.utils.get_device_properties(
+                                 triton.runtime.driver.active.get_current_device())["max_shared_mem"] < 196608:
+            pytest.skip("XPU: Not enough shared memory")
     if not PACK_B_ALONG_K and B_DATA_TYPE != "float4":
         pytest.xfail("Pack along K can only be False for float4")
 
@@ -1288,6 +1296,8 @@ def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bo
     kernel_kwargs = {}
     if is_hip():
         kernel_kwargs["matrix_instr_nonkdim"] = nonKDim
+    if is_xpu() and (128, 256, 256) == (BLOCK_M, BLOCK_N, BLOCK_K) and not CONST_SCALE and not PACK_B_ALONG_K:
+        kernel_kwargs["num_warps"] = 8
     out = mxfp8_mxfp4_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, stride_scale, a.stride(0), a.stride(1),
                                    b.stride(0), b.stride(1), output.stride(0), output.stride(1), not CONST_SCALE,
                                    dtype_converter[A_DATA_TYPE], dtype_converter[B_DATA_TYPE], BLOCK_M, BLOCK_N,
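For reference, a minimal sketch (not part of the diff) of the shared-memory probe that the new skip condition relies on, assuming the same Triton driver API used above, where `get_device_properties(...)["max_shared_mem"]` reports the device's shared-memory capacity in bytes:

```python
import triton


def xpu_shared_mem_bytes() -> int:
    # Query the active device's shared-memory capacity via the same driver
    # calls that the skip condition in the diff uses.
    device = triton.runtime.driver.active.get_current_device()
    props = triton.runtime.driver.active.utils.get_device_properties(device)
    return props["max_shared_mem"]


# The (BLOCK_M, BLOCK_N, BLOCK_K) = (128, 256, 256) CONST_SCALE case is skipped
# when fewer than 196608 bytes (192 KiB) of shared memory are available.
if xpu_shared_mem_bytes() < 196608:
    print("XPU: Not enough shared memory for the 128x256x256 tiling")
```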