6 | 6 | import triton.tools.experimental_descriptor |
7 | 7 | from test_mxfp import MXFP4Tensor, MXScaleTensor |
8 | 8 | import re |
9 | | -from triton._internal_testing import is_cuda, is_hip, is_hip_mi200 |
| 9 | +from triton._internal_testing import is_cuda, is_hip, is_hip_mi200, is_xpu |
10 | 10 |
11 | 11 |
12 | 12 | def f8_to_f16(x, dtype): |
@@ -82,7 +82,7 @@ def get_src_element_ty_size(dtype_str): |
82 | 82 | def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, NUM_WARPS, NUM_CTAS, |
83 | 83 | device): |
84 | 84 | if NUM_CTAS > 1 and (not is_cuda() or torch.cuda.get_device_capability()[0] < 9): |
85 | | - pytest.skip("Clusters requires nvidia compute capability >= 9") |
| 85 | + pytest.xfail("Clusters require NVIDIA compute capability >= 9")
86 | 86 | if is_hip() and ((BLOCK_K * BLOCK_M + BLOCK_K * BLOCK_N) * NUM_STAGES * get_src_element_ty_size(dtype_src_str) |
87 | 87 | > 65536): |
88 | 88 | pytest.skip("HIP path requires less than 64KB of shared memory") |
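Note on the skip → xfail change in the hunk above (not part of the diff itself): the imperative pytest.xfail() stops the test immediately, like pytest.skip(), but reports it as an expected failure, so the unsupported cluster configuration still shows up in the xfail counts instead of disappearing from the report. A minimal sketch with a hypothetical test name:

import pytest

def test_cluster_capability_sketch():
    cluster_supported = False  # stand-in for the NUM_CTAS / compute-capability check
    if not cluster_supported:
        # reported as xfailed rather than skipped
        pytest.xfail("Clusters require NVIDIA compute capability >= 9")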
@@ -316,8 +316,11 @@ def fp8e8m0_to_float32(scale): |
316 | 316 | @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 128), (256, 128, 128), (128, 256, 128), |
317 | 317 | (128, 256, 256), (128, 128, 64), (128, 64, 128)]) |
318 | 318 | @pytest.mark.parametrize("NUM_STAGES", [1, 3]) |
319 | | -@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10") |
| 319 | +@pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] < 10, |
| 320 | + reason="Requires compute capability >= 10") |
320 | 321 | def test_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, device): |
| 322 | + if is_xpu(): |
| 323 | + pytest.skip("FIXME: fails with RuntimeError on XPU")
321 | 324 | if BLOCK_N == 256 and BLOCK_K == 256: |
322 | 325 | NUM_STAGES = min(NUM_STAGES, 2) |
323 | 326 | torch.manual_seed(42) |
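Note on the guarded skipif pattern applied throughout this diff (illustration only, not part of the diff): the skipif condition is evaluated at collection time, so an unconditional torch.cuda.get_device_capability() call raises a RuntimeError on builds without a CUDA device; short-circuiting on is_cuda() keeps collection working on HIP and XPU machines. A minimal sketch with a hypothetical test name:

import pytest
import torch
from triton._internal_testing import is_cuda

# is_cuda() is checked first so the CUDA query only runs when a CUDA backend is present.
@pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] < 10,
                    reason="Requires compute capability >= 10")
def test_capability_guard_sketch():
    pass  # body elided; only the collection-time guard matters here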
@@ -442,8 +445,11 @@ def block_scale_mxfp_matmul( # |
442 | 445 | (128, 128, 256), (128, 256, 256)]) |
443 | 446 | @pytest.mark.parametrize("NUM_STAGES", [1, 2, 4]) |
444 | 447 | @pytest.mark.parametrize("USE_2D_SCALE_LOAD", [False, True]) |
445 | | -@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10") |
| 448 | +@pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] < 10, |
| 449 | + reason="Requires compute capability >= 10") |
446 | 450 | def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_2D_SCALE_LOAD, device): |
| 451 | + if is_xpu(): |
| 452 | + pytest.skip("FIXME: fails with RuntimeError on XPU")
447 | 453 | if BLOCK_N == 256 and BLOCK_K == 256: |
448 | 454 | NUM_STAGES = min(NUM_STAGES, 2) |
449 | 455 | elif BLOCK_K == 256: |
@@ -564,7 +570,8 @@ def lhs_in_tmem_kernel( # |
564 | 570 | (128, 256, 256), (128, 128, 64), (128, 64, 128)]) |
565 | 571 | @pytest.mark.parametrize("a_trans", [False, True]) |
566 | 572 | @pytest.mark.parametrize("dtype_src_str", ["float32", "float16", "float8e5"]) |
567 | | -@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10") |
| 573 | +@pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] < 10, |
| 574 | + reason="Requires compute capability >= 10") |
568 | 575 | def test_lhs_in_tmem(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, a_trans, dtype_src_str, device, monkeypatch): |
569 | 576 | _knob_promote_lhs_to_tmem(monkeypatch) |
570 | 577 | if M != BLOCK_M or N != BLOCK_N or K != BLOCK_K: |
@@ -628,8 +635,11 @@ def lhs_in_tmem_kernel_mxfp( # |
628 | 635 | tl.store(output_ptrs, accumulator) |
629 | 636 |
630 | 637 |
631 | | -@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10") |
| 638 | +@pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] < 10, |
| 639 | + reason="Requires compute capability >= 10") |
632 | 640 | def test_lhs_in_tmem_mxfp(device, monkeypatch): |
| 641 | + if is_xpu(): |
| 642 | + pytest.skip("FIXME: failed to legalize operation 'tt.dot_scaled' on XPU") |
633 | 643 | _knob_promote_lhs_to_tmem(monkeypatch) |
634 | 644 | M, N, K = 128, 64, 32 |
635 | 645 | torch.manual_seed(42) |
@@ -712,8 +722,11 @@ def block_scale_fp4_matmul( # |
712 | 722 | (128, 256, 256), (128, 128, 64), (128, 64, 128)]) |
713 | 723 | @pytest.mark.parametrize(("scale_type", "VEC_SIZE"), [("float8_e8m0fnu", 32), ("float8_e4m3fn", 16)], |
714 | 724 | ids=["mxfp4", "nvfp4"]) |
715 | | -@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10") |
| 725 | +@pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] < 10, |
| 726 | + reason="Requires compute capability >= 10") |
716 | 727 | def test_block_scale_fp4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, VEC_SIZE, scale_type, device): |
| 728 | + if is_xpu(): |
| 729 | + pytest.skip("FIXME: failed to legalize operation 'tt.dot_scaled' on XPU") |
717 | 730 | NUM_STAGES = 1 |
718 | 731 | torch.manual_seed(42) |
719 | 732 | a_mxfp4 = MXFP4Tensor(size=(M, K), device=device).random() |