diff --git a/python/test/unit/language/test_matmul.py b/python/test/unit/language/test_matmul.py index d440c38a7c..06b52c2773 100644 --- a/python/test/unit/language/test_matmul.py +++ b/python/test/unit/language/test_matmul.py @@ -1217,7 +1217,15 @@ def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TR if (A_DATA_TYPE == 'float4' and not WITH_A_SCALE) or (B_DATA_TYPE == 'float4' and not WITH_B_SCALE): pytest.skip("Float4 without scale is tested in test_block_scale_fp4") elif is_xpu(): - pytest.xfail("XPU does not natively support scaled mxfp8 & mxfp4 matmul") + if not (WITH_A_SCALE and WITH_B_SCALE): + pytest.skip("None scale has not been tested on XPU backend") + if not (A_DATA_TYPE == "float8e5" and B_DATA_TYPE == "float4"): + pytest.skip(f"(A: {A_DATA_TYPE}, B: {B_DATA_TYPE}) has not been tested on XPU backend") + if (BLOCK_M, BLOCK_N, + BLOCK_K) == (128, 256, + 256) and CONST_SCALE and triton.runtime.driver.active.utils.get_device_properties( + triton.runtime.driver.active.get_current_device())["max_shared_mem"] < 196608: + pytest.skip("XPU: Not enough shared memory") if not PACK_B_ALONG_K and B_DATA_TYPE != "float4": pytest.xfail("Pack along K can only be False for float4") @@ -1288,6 +1296,8 @@ def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bo kernel_kwargs = {} if is_hip(): kernel_kwargs["matrix_instr_nonkdim"] = nonKDim + if is_xpu() and (128, 256, 256) == (BLOCK_M, BLOCK_N, BLOCK_K) and not CONST_SCALE and not PACK_B_ALONG_K: + kernel_kwargs["num_warps"] = 8 out = mxfp8_mxfp4_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, stride_scale, a.stride(0), a.stride(1), b.stride(0), b.stride(1), output.stride(0), output.stride(1), not CONST_SCALE, dtype_converter[A_DATA_TYPE], dtype_converter[B_DATA_TYPE], BLOCK_M, BLOCK_N, diff --git a/scripts/test-triton.sh b/scripts/test-triton.sh index bc347f4a00..e890e0db17 100755 --- a/scripts/test-triton.sh +++ b/scripts/test-triton.sh @@ -402,7 +402,7 @@ run_mxfp_tests() { cd $TRITON_PROJ/python/test/unit TRITON_DISABLE_LINE_INFO=1 TRITON_TEST_SUITE=mxfp \ - run_pytest_command -vvv -n ${PYTEST_MAX_PROCESSES:-8} --device xpu intel/test_mxfp_matmul.py + run_pytest_command -vvv -n ${PYTEST_MAX_PROCESSES:-8} --device xpu intel/test_mxfp_matmul.py language/test_matmul.py::test_mxfp8_mxfp4_matmul } run_scaled_dot_tests() {