Skip to content
12 changes: 11 additions & 1 deletion python/test/unit/language/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,7 +1217,15 @@ def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TR
if (A_DATA_TYPE == 'float4' and not WITH_A_SCALE) or (B_DATA_TYPE == 'float4' and not WITH_B_SCALE):
pytest.skip("Float4 without scale is tested in test_block_scale_fp4")
elif is_xpu():
pytest.xfail("XPU does not natively support scaled mxfp8 & mxfp4 matmul")
if not (WITH_A_SCALE and WITH_B_SCALE):
pytest.skip("None scale has not been tested on XPU backend")
if not (A_DATA_TYPE == "float8e5" and B_DATA_TYPE == "float4"):
pytest.skip(f"(A: {A_DATA_TYPE}, B: {B_DATA_TYPE}) has not been tested on XPU backend")
if (BLOCK_M, BLOCK_N,
BLOCK_K) == (128, 256,
256) and CONST_SCALE and triton.runtime.driver.active.utils.get_device_properties(
triton.runtime.driver.active.get_current_device())["max_shared_mem"] < 196608:
pytest.skip("XPU: Not enough shared memory")
if not PACK_B_ALONG_K and B_DATA_TYPE != "float4":
pytest.xfail("Pack along K can only be False for float4")

Expand Down Expand Up @@ -1288,6 +1296,8 @@ def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bo
kernel_kwargs = {}
if is_hip():
kernel_kwargs["matrix_instr_nonkdim"] = nonKDim
if is_xpu() and (128, 256, 256) == (BLOCK_M, BLOCK_N, BLOCK_K) and not CONST_SCALE and not PACK_B_ALONG_K:
kernel_kwargs["num_warps"] = 8
out = mxfp8_mxfp4_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, stride_scale, a.stride(0), a.stride(1),
b.stride(0), b.stride(1), output.stride(0), output.stride(1), not CONST_SCALE,
dtype_converter[A_DATA_TYPE], dtype_converter[B_DATA_TYPE], BLOCK_M, BLOCK_N,
Expand Down
2 changes: 1 addition & 1 deletion scripts/test-triton.sh
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ run_mxfp_tests() {
cd $TRITON_PROJ/python/test/unit

TRITON_DISABLE_LINE_INFO=1 TRITON_TEST_SUITE=mxfp \
run_pytest_command -vvv -n ${PYTEST_MAX_PROCESSES:-8} --device xpu intel/test_mxfp_matmul.py
run_pytest_command -vvv -n ${PYTEST_MAX_PROCESSES:-8} --device xpu intel/test_mxfp_matmul.py language/test_matmul.py::test_mxfp8_mxfp4_matmul
}

run_scaled_dot_tests() {
Expand Down
Loading