Commit 3f44826

[TEST] Fix UT and tutorial failures from b39c1e1
Signed-off-by: Whitney Tsang <[email protected]>
1 parent d18a065 commit 3f44826

4 files changed: +56 -9 lines changed

python/test/unit/language/test_matmul.py

Lines changed: 20 additions & 7 deletions
@@ -6,7 +6,7 @@
 import triton.tools.experimental_descriptor
 from test_mxfp import MXFP4Tensor, MXScaleTensor
 import re
-from triton._internal_testing import is_cuda, is_hip, is_hip_mi200
+from triton._internal_testing import is_cuda, is_hip, is_hip_mi200, is_xpu
 
 
 def f8_to_f16(x, dtype):
@@ -82,7 +82,7 @@ def get_src_element_ty_size(dtype_str):
 def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, NUM_WARPS, NUM_CTAS,
                        device):
     if NUM_CTAS > 1 and (not is_cuda() or torch.cuda.get_device_capability()[0] < 9):
-        pytest.skip("Clusters requires nvidia compute capability >= 9")
+        pytest.xfail("Clusters requires nvidia compute capability >= 9")
     if is_hip() and ((BLOCK_K * BLOCK_M + BLOCK_K * BLOCK_N) * NUM_STAGES * get_src_element_ty_size(dtype_src_str)
                      > 65536):
         pytest.skip("HIP path requires less than 64KB of shared memory")
@@ -316,8 +316,11 @@ def fp8e8m0_to_float32(scale):
 @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 128), (256, 128, 128), (128, 256, 128),
                                                        (128, 256, 256), (128, 128, 64), (128, 64, 128)])
 @pytest.mark.parametrize("NUM_STAGES", [1, 3])
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
+@pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] < 10,
+                    reason="Requires compute capability >= 10")
 def test_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, device):
+    if is_xpu():
+        pytest.skip("FIXME: Fail RuntimeError on XPU")
     if BLOCK_N == 256 and BLOCK_K == 256:
         NUM_STAGES = min(NUM_STAGES, 2)
     torch.manual_seed(42)
@@ -442,8 +445,11 @@ def block_scale_mxfp_matmul( #
                          (128, 128, 256), (128, 256, 256)])
 @pytest.mark.parametrize("NUM_STAGES", [1, 2, 4])
 @pytest.mark.parametrize("USE_2D_SCALE_LOAD", [False, True])
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
+@pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] < 10,
+                    reason="Requires compute capability >= 10")
 def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_2D_SCALE_LOAD, device):
+    if is_xpu():
+        pytest.skip("FIXME: Fail RuntimeError on XPU")
     if BLOCK_N == 256 and BLOCK_K == 256:
         NUM_STAGES = min(NUM_STAGES, 2)
     elif BLOCK_K == 256:
@@ -564,7 +570,8 @@ def lhs_in_tmem_kernel( #
                                                        (128, 256, 256), (128, 128, 64), (128, 64, 128)])
 @pytest.mark.parametrize("a_trans", [False, True])
 @pytest.mark.parametrize("dtype_src_str", ["float32", "float16", "float8e5"])
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
+@pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] < 10,
+                    reason="Requires compute capability >= 10")
 def test_lhs_in_tmem(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, a_trans, dtype_src_str, device, monkeypatch):
     _knob_promote_lhs_to_tmem(monkeypatch)
     if M != BLOCK_M or N != BLOCK_N or K != BLOCK_K:
@@ -628,8 +635,11 @@ def lhs_in_tmem_kernel_mxfp( #
     tl.store(output_ptrs, accumulator)
 
 
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
+@pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] < 10,
+                    reason="Requires compute capability >= 10")
 def test_lhs_in_tmem_mxfp(device, monkeypatch):
+    if is_xpu():
+        pytest.skip("FIXME: failed to legalize operation 'tt.dot_scaled' on XPU")
     _knob_promote_lhs_to_tmem(monkeypatch)
     M, N, K = 128, 64, 32
     torch.manual_seed(42)
@@ -712,8 +722,11 @@ def block_scale_fp4_matmul( #
                                                        (128, 256, 256), (128, 128, 64), (128, 64, 128)])
 @pytest.mark.parametrize(("scale_type", "VEC_SIZE"), [("float8_e8m0fnu", 32), ("float8_e4m3fn", 16)],
                          ids=["mxfp4", "nvfp4"])
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
+@pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] < 10,
+                    reason="Requires compute capability >= 10")
 def test_block_scale_fp4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, VEC_SIZE, scale_type, device):
+    if is_xpu():
+        pytest.skip("FIXME: failed to legalize operation 'tt.dot_scaled' on XPU")
     NUM_STAGES = 1
     torch.manual_seed(42)
     a_mxfp4 = MXFP4Tensor(size=(M, K), device=device).random()

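The pattern in every hunk above is the same: conditions passed to pytest.mark.skipif are evaluated when the module is imported, so an unguarded torch.cuda.get_device_capability() call can raise on machines without a CUDA runtime, while the XPU-specific failures are recorded with an explicit skip inside the test body. A minimal sketch of that pattern, assuming only pytest, torch, and the is_cuda/is_xpu helpers from triton._internal_testing (the requires_blackwell marker name and test_example are hypothetical, used here only for illustration):

# Sketch of the guard pattern applied in the hunks above. The marker name
# requires_blackwell and the test name are hypothetical, not part of test_matmul.py.
import pytest
import torch
from triton._internal_testing import is_cuda, is_xpu

# skipif conditions run at collection time; short-circuiting on is_cuda() keeps
# the CUDA capability query from raising on XPU or CPU-only machines.
requires_blackwell = pytest.mark.skipif(
    is_cuda() and torch.cuda.get_device_capability()[0] < 10,
    reason="Requires compute capability >= 10")


@requires_blackwell
def test_example(device):
    if is_xpu():
        pytest.skip("FIXME: Fail RuntimeError on XPU")  # known XPU failure, skipped explicitly
    ...  # actual test body
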
python/test/unit/language/test_pipeliner.py

Lines changed: 3 additions & 1 deletion
@@ -504,9 +504,11 @@ def matmul_kernel_persistent_scatter(a_ptr, b_ptr, c_ptr, #
     c_desc.scatter(c, offs_am + tl.arange(0, BLOCK_SIZE_M), offs_bn)
 
 
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10,
+@pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] != 10,
                     reason="TMA Scatter only works on cloud Blackwell Chips")
 def test_scatter_pipeline(device):
+    if is_xpu():
+        pytest.xfail("XPU does not support TMA scatter")
 
     def alloc_fn(size, alignment, stream):
         return torch.empty(size, device="cuda", dtype=torch.int8)

python/tutorials/06-fused-attention.py

Lines changed: 2 additions & 1 deletion
@@ -339,7 +339,8 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
 def keep_tma(conf):
     BLOCK_M = conf.kwargs["BLOCK_M"]
     BLOCK_N = conf.kwargs["BLOCK_N"]
-    if (torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8):
+    if (is_cuda() and torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128
+            and conf.num_warps == 8):
         return False
     return True
 

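The tutorial fix applies the same short-circuit inside an autotune config filter: keep_tma only queries the device capability when is_cuda() is true, so pruning configs stays safe on XPU. A hedged sketch of how such a predicate can be wired into a config list; the config values and the use of Python's built-in filter() below are illustrative and not taken from the tutorial:

# Illustrative wiring of a keep_tma-style predicate; the config values are made up.
import torch
import triton
from triton._internal_testing import is_cuda


def keep_tma(conf):
    BLOCK_M = conf.kwargs["BLOCK_M"]
    BLOCK_N = conf.kwargs["BLOCK_N"]
    # is_cuda() short-circuits the capability query, so the filter never touches
    # the CUDA runtime on XPU or CPU-only machines.
    if (is_cuda() and torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128
            and conf.num_warps == 8):
        return False
    return True


configs = [triton.Config({"BLOCK_M": bm, "BLOCK_N": bn}, num_stages=s, num_warps=w)
           for bm in (64, 128) for bn in (64, 128) for s in (3, 4) for w in (4, 8)]
tma_configs = list(filter(keep_tma, configs))  # e.g. fed to triton.autotune(configs=tma_configs, ...)
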
scripts/skiplist/lts/language.txt

Lines changed: 31 additions & 0 deletions
@@ -2014,3 +2014,34 @@ test/unit/language/test_core.py::test_scaled_dot[64-64-64-True-True-True-e5m2-e5
 test/unit/language/test_core.py::test_scaled_dot[64-64-64-True-True-True-e5m2-fp16-4-16-1]
 test/unit/language/test_core.py::test_trans_reshape
 test/unit/language/test_pipeliner.py::test_pipeline_matmul[True]
+test/unit/language/test_matmul.py::test_lhs_in_tmem[float16-False-128-128-128-128-128-128]
+test/unit/language/test_matmul.py::test_lhs_in_tmem[float16-True-128-128-128-128-128-128]
+test/unit/language/test_matmul.py::test_lhs_in_tmem[float32-False-128-128-128-128-128-128]
+test/unit/language/test_matmul.py::test_lhs_in_tmem[float32-True-128-128-128-128-128-128]
+test/unit/language/test_matmul.py::test_lhs_in_tmem[float8e5-False-128-128-128-128-128-128]
+test/unit/language/test_matmul.py::test_lhs_in_tmem[float8e5-True-128-128-128-128-128-128]
+test/unit/language/test_matmul.py::test_simple_matmul[4-1-128-128-16-4-float16-float16]
+test/unit/language/test_matmul.py::test_simple_matmul[4-1-256-128-32-4-float16-float16]
+test/unit/language/test_matmul.py::test_simple_matmul[4-1-256-128-32-4-float32-float16]
+test/unit/language/test_matmul.py::test_simple_matmul[4-1-256-128-32-4-float32-float8e5]
+test/unit/language/test_matmul.py::test_simple_matmul[4-1-32-32-32-4-float16-float16]
+test/unit/language/test_matmul.py::test_simple_matmul[4-1-512-64-32-2-float16-float16]
+test/unit/language/test_matmul.py::test_simple_matmul[4-1-512-64-32-2-float32-float16]
+test/unit/language/test_matmul.py::test_simple_matmul[4-1-512-64-32-2-float32-float8e5]
+test/unit/language/test_matmul.py::test_simple_matmul[4-1-64-128-32-4-float16-float16]
+test/unit/language/test_matmul.py::test_simple_matmul[4-1-64-512-32-2-float16-float16]
+test/unit/language/test_matmul.py::test_simple_matmul[8-1-128-128-16-4-float16-float16]
+test/unit/language/test_matmul.py::test_simple_matmul[8-1-256-128-32-4-float16-float16]
+test/unit/language/test_matmul.py::test_simple_matmul[8-1-32-32-32-4-float16-float16]
+test/unit/language/test_matmul.py::test_simple_matmul[8-1-512-64-32-2-float16-float16]
+test/unit/language/test_matmul.py::test_simple_matmul[8-1-64-128-32-4-float16-float16]
+test/unit/language/test_matmul.py::test_simple_matmul[8-1-64-512-32-2-float16-float16]
+test/unit/language/test_pipeliner.py::test_indirect_matmul[1-128-128-128]
+test/unit/language/test_pipeliner.py::test_indirect_matmul[1-128-128-64]
+test/unit/language/test_pipeliner.py::test_indirect_matmul[1-128-64-128]
+test/unit/language/test_pipeliner.py::test_indirect_matmul[3-128-128-128]
+test/unit/language/test_pipeliner.py::test_indirect_matmul[3-128-128-64]
+test/unit/language/test_pipeliner.py::test_indirect_matmul[3-128-64-128]
+test/unit/language/test_pipeliner.py::test_indirect_matmul[5-128-128-128]
+test/unit/language/test_pipeliner.py::test_indirect_matmul[5-128-128-64]
+test/unit/language/test_pipeliner.py::test_indirect_matmul[5-128-64-128]
