Skip to content

Commit 6217cec

Browse files
[test_mxfp8_mxfp4_matmul] Use skip list instead of the in-code skips (#4087)
Fixes #3903
1 parent fdcf859 commit 6217cec

File tree

2 files changed

+983
-25
lines changed

2 files changed

+983
-25
lines changed

python/test/unit/language/test_matmul.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -922,31 +922,10 @@ def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TR
922922
pytest.xfail("Pack along K can only be False for float4")
923923

924924
if is_xpu():
925-
if A_DATA_TYPE == B_DATA_TYPE == "float4":
926-
pytest.skip("https://github.com/intel/intel-xpu-backend-for-triton/issues/3777")
927-
elif not (WITH_B_SCALE or PACK_B_ALONG_K) and B_DATA_TYPE == "float4" and \
928-
A_DATA_TYPE in ("float8e5", "float8e4nv"):
929-
pytest.skip("https://github.com/intel/intel-xpu-backend-for-triton/issues/4045")
930-
elif WITH_B_SCALE and not PACK_B_ALONG_K and B_DATA_TYPE == "float4" and \
931-
A_DATA_TYPE in ("float8e5", "float8e4nv"):
932-
pytest.skip("https://github.com/intel/intel-xpu-backend-for-triton/issues/3908")
933-
elif (BLOCK_M, BLOCK_N, BLOCK_K) == (128, 256, 256):
934-
if triton.runtime.driver.active.utils.get_device_properties(
935-
triton.runtime.driver.active.get_current_device())["max_shared_mem"] < 196608:
936-
pytest.xfail("Not enough shared memory")
937-
else:
938-
pass
939-
elif (BLOCK_M, BLOCK_N, BLOCK_K) in ((128, 64, 128), (128, 128, 128)):
940-
pass
941-
elif (BLOCK_M, BLOCK_N, BLOCK_K) in (128, 128, 64):
942-
if A_DATA_TYPE in ("float8e5", "float8e4nv") and B_DATA_TYPE in ("float8e5", "float8e4nv") \
943-
and WITH_B_SCALE == CONST_SCALE \
944-
and WITH_A_SCALE and B_TRANS and PACK_B_ALONG_K:
945-
pytest.skip("https://github.com/intel/intel-xpu-backend-for-triton/issues/3677")
946-
pass
947-
else:
948-
# Some tests pass, but it's difficult to filter them out
949-
pytest.skip("https://github.com/intel/intel-xpu-backend-for-triton/issues/3677")
925+
required_sm = BLOCK_M * BLOCK_K * 2 + BLOCK_N * BLOCK_K * 2
926+
if triton.runtime.driver.active.utils.get_device_properties(
927+
triton.runtime.driver.active.get_current_device())["max_shared_mem"] < required_sm:
928+
pytest.xfail("Not enough shared memory")
950929

951930
if BLOCK_N == 256 and BLOCK_K == 256:
952931
NUM_STAGES = 2

0 commit comments

Comments
 (0)