Skip to content

Commit 0cc2c18

Browse files
[TEST] Fix test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma (#4402)
There are 6 failed, 18 passed for `test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma`. The failures are cause by triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 196608, Hardware limit: 131072. This change checks and marks out of memory cases as XFAIL.
1 parent d599763 commit 0cc2c18

File tree

3 files changed

+6
-21
lines changed

3 files changed

+6
-21
lines changed

python/test/unit/language/test_tensor_descriptor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,12 @@ def test_mxfp8_mxfp4_matmul_tma(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES,
12701270
if BLOCK_K < K and is_cuda() and torch.cuda.get_device_capability(0)[0] != 10:
12711271
pytest.skip("Currently broken on hopper")
12721272

1273+
required_sm = BLOCK_M * BLOCK_K * 2 + BLOCK_K * BLOCK_N * 2
1274+
max_sm = triton.runtime.driver.active.utils.get_device_properties(
1275+
triton.runtime.driver.active.get_current_device())["max_shared_mem"]
1276+
if is_xpu() and required_sm > max_sm:
1277+
pytest.xfail(f"Not enough shared memory for the given block size ({BLOCK_M}, {BLOCK_N}, {BLOCK_K})")
1278+
12731279
a = torch.randint(20, 40, (M, K), dtype=torch.uint8).view(torch.float8_e5m2).to(device)
12741280

12751281
dtype_src_str = "float8e5"

scripts/skiplist/default/language.txt

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,2 @@
11
# https://github.com/intel/intel-xpu-backend-for-triton/issues/3870
22
python/test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float32-float32]
3-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/4222
4-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[1-128-128-256-8192-8192-8192]
5-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[1-128-256-128-8192-8192-8192]
6-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[1-128-256-256-1024-512-256]
7-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[1-128-256-256-128-256-256]
8-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[1-128-256-256-8192-8192-8192]
9-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[3-128-128-256-8192-8192-8192]
10-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[3-128-256-128-8192-8192-8192]
11-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[3-128-256-256-1024-512-256]
12-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[3-128-256-256-128-256-256]
13-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[3-128-256-256-8192-8192-8192]

scripts/skiplist/lts/language.txt

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -369,15 +369,5 @@ python/test/unit/language/test_pipeliner.py::test_indirect_matmul[5-128-128-64]
369369
python/test/unit/language/test_pipeliner.py::test_indirect_matmul[5-128-64-128]
370370
python/test/unit/language/test_core.py::test_convert_mma2mma[mma_pair0-float16-256-256]
371371
python/test/unit/language/test_matmul.py::test_lhs_in_tmem
372-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[1-128-128-256-8192-8192-8192]
373-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[1-128-256-128-8192-8192-8192]
374-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[1-128-256-256-1024-512-256]
375-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[1-128-256-256-128-256-256]
376-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[1-128-256-256-8192-8192-8192]
377-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[3-128-128-256-8192-8192-8192]
378-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[3-128-256-128-8192-8192-8192]
379-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[3-128-256-256-1024-512-256]
380-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[3-128-256-256-128-256-256]
381-
python/test/unit/language/test_tensor_descriptor.py::test_mxfp8_mxfp4_matmul_tma[3-128-256-256-8192-8192-8192]
382372
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_batched_gemm_2d_tma
383373
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_batched_gemm_3d_tma

0 commit comments

Comments
 (0)