Skip to content

Commit eff6a02

Browse files
Reenable some mxfp tests for XPU (#5379)
Addresses #4062
* skips same as CUDA
* skip large block tests for PVC
* add num_warps=8 for empty kernels case
1 parent 5f7ccc1 commit eff6a02

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

python/test/unit/language/test_matmul.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1217,7 +1217,15 @@ def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TR
12171217
if (A_DATA_TYPE == 'float4' and not WITH_A_SCALE) or (B_DATA_TYPE == 'float4' and not WITH_B_SCALE):
12181218
pytest.skip("Float4 without scale is tested in test_block_scale_fp4")
12191219
elif is_xpu():
1220-
pytest.xfail("XPU does not natively support scaled mxfp8 & mxfp4 matmul")
1220+
if not (WITH_A_SCALE and WITH_B_SCALE):
1221+
pytest.skip("None scale has not been tested on XPU backend")
1222+
if not (A_DATA_TYPE == "float8e5" and B_DATA_TYPE == "float4"):
1223+
pytest.skip(f"(A: {A_DATA_TYPE}, B: {B_DATA_TYPE}) has not been tested on XPU backend")
1224+
if (BLOCK_M, BLOCK_N,
1225+
BLOCK_K) == (128, 256,
1226+
256) and CONST_SCALE and triton.runtime.driver.active.utils.get_device_properties(
1227+
triton.runtime.driver.active.get_current_device())["max_shared_mem"] < 196608:
1228+
pytest.skip("XPU: Not enough shared memory")
12211229
if not PACK_B_ALONG_K and B_DATA_TYPE != "float4":
12221230
pytest.xfail("Pack along K can only be False for float4")
12231231

@@ -1288,6 +1296,8 @@ def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bo
12881296
kernel_kwargs = {}
12891297
if is_hip():
12901298
kernel_kwargs["matrix_instr_nonkdim"] = nonKDim
1299+
if is_xpu() and (128, 256, 256) == (BLOCK_M, BLOCK_N, BLOCK_K) and not CONST_SCALE and not PACK_B_ALONG_K:
1300+
kernel_kwargs["num_warps"] = 8
12911301
out = mxfp8_mxfp4_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, stride_scale, a.stride(0), a.stride(1),
12921302
b.stride(0), b.stride(1), output.stride(0), output.stride(1), not CONST_SCALE,
12931303
dtype_converter[A_DATA_TYPE], dtype_converter[B_DATA_TYPE], BLOCK_M, BLOCK_N,

scripts/test-triton.sh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -402,7 +402,7 @@ run_mxfp_tests() {
402402
cd $TRITON_PROJ/python/test/unit
403403

404404
TRITON_DISABLE_LINE_INFO=1 TRITON_TEST_SUITE=mxfp \
405-
run_pytest_command -vvv -n ${PYTEST_MAX_PROCESSES:-8} --device xpu intel/test_mxfp_matmul.py
405+
run_pytest_command -vvv -n ${PYTEST_MAX_PROCESSES:-8} --device xpu intel/test_mxfp_matmul.py language/test_matmul.py::test_mxfp8_mxfp4_matmul
406406
}
407407

408408
run_scaled_dot_tests() {

0 commit comments

Comments
 (0)