Commit db0c34c
[AMD] Disable unsupported matmul tests (#6439)
Disabled:
- language/test_matmul.py::test_blocked_scale_mxfp
- language/test_matmul.py::test_lhs_in_tmem
- language/test_matmul.py::test_lhs_in_tmem_mxfp

Signed-off-by: Ilya Veselov <[email protected]>
1 parent 11b288c commit db0c34c
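
The pattern used throughout this diff prepends a backend check to the existing compute-capability gate. A minimal sketch of the idea, assuming `is_hip()` is implemented as a check on the PyTorch ROCm build (the repository's actual helper may differ):

import pytest
import torch

def is_hip():
    # Stand-in for Triton's test helper: ROCm builds of PyTorch set
    # torch.version.hip to a version string; CUDA builds leave it as None.
    return torch.version.hip is not None

# `or` short-circuits, so on a ROCm device the CUDA compute-capability
# query on the right-hand side is never evaluated.
@pytest.mark.skipif(is_hip() or torch.cuda.get_device_capability()[0] < 10,
                    reason="Requires compute capability >= 10")
def test_requires_blackwell():
    ...

Note that the skipif condition is evaluated once at collection time, which is why the cheap backend check goes first.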

File tree: 1 file changed (+3, -3)

python/test/unit/language/test_matmul.py

Lines changed: 3 additions & 3 deletions
@@ -468,7 +468,7 @@ def block_scale_mxfp_matmul( #
                          (128, 128, 256), (128, 256, 256)])
 @pytest.mark.parametrize("NUM_STAGES", [1, 2, 4])
 @pytest.mark.parametrize("USE_2D_SCALE_LOAD", [False, True])
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
+@pytest.mark.skipif(is_hip() or torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
 def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_2D_SCALE_LOAD, device):
     if BLOCK_N == 256 and BLOCK_K == 256:
         NUM_STAGES = min(NUM_STAGES, 2)
@@ -540,7 +540,7 @@ def flatten_scale(scale):
 @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 64), (128, 64, 128), (64, 128, 32), (128, 256, 32)])
 @pytest.mark.parametrize("a_trans", [False, True])
 @pytest.mark.parametrize("dtype_src_str", ["float32", "float16", "float8e5"])
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
+@pytest.mark.skipif(is_hip() or torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
 def test_lhs_in_tmem(BLOCK_M, BLOCK_N, BLOCK_K, a_trans, dtype_src_str, device, monkeypatch):
     M = 1024
     N = 512
@@ -604,7 +604,7 @@ def lhs_in_tmem_kernel_mxfp( #
     tl.store(output_ptrs, accumulator)


-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
+@pytest.mark.skipif(is_hip() or torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
 def test_lhs_in_tmem_mxfp(device, monkeypatch):
     _knob_promote_lhs_to_tmem(monkeypatch)
     M, N, K = 128, 64, 32
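
To confirm the markers take effect, one could run the three tests on a ROCm machine and check that they are reported as skipped; a possible local invocation (the `-k` expression and `-rs` skip-report flag are standard pytest options):

pytest python/test/unit/language/test_matmul.py -k "blocked_scale_mxfp or lhs_in_tmem" -rs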
