Skip to content

Commit 9e9a7b7

Browse files
dshi7meta-codesync[bot]
authored andcommitted
modify skipIf for TLX UT (#694)
Summary: see [D87796244](https://www.internalfb.com/diff/D87796244) Pull Request resolved: #694 Reviewed By: htyu Differential Revision: D87812601 Pulled By: dshi7 fbshipit-source-id: 33027116d6e3e4f8b73369edeaad3c6eacd087d7
1 parent fd6c6b4 commit 9e9a7b7

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

python/test/unit/language/test_tlx.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import re
66
import triton
77
import triton.language as tl
8-
from triton._internal_testing import is_hopper_or_newer, is_blackwell, is_hopper, is_hip, is_cuda
8+
from triton._internal_testing import is_hopper_or_newer, is_blackwell, is_hopper, is_hip
99
import triton.language.extra.tlx as tlx
1010
from typing import Optional
1111
import traceback
@@ -585,6 +585,7 @@ def test_cta_0_kernel(
585585
torch.testing.assert_close(output, expected_output)
586586

587587

588+
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")
588589
def test_clock64(device):
589590

590591
@triton.jit
@@ -1245,6 +1246,7 @@ def tcgen5_dot_kernel2cta_tma_ws(a_ptr, stride_am, stride_ak, b_ptr, stride_bk,
12451246

12461247
@pytest.mark.parametrize("A_DATA_TYPE", ["e5m2", "e4m3"])
12471248
@pytest.mark.parametrize("B_DATA_TYPE", ["e5m2", "e4m3"])
1249+
@pytest.mark.skipif(not is_blackwell(), reason="Need Blackwell")
12481250
def test_async_dot_scaled(A_DATA_TYPE, B_DATA_TYPE, device):
12491251
"""
12501252
Test D = (A * A_scale) * (B * B_scale)
@@ -1633,9 +1635,8 @@ def run_tlx_square(func, BLOCK_SIZE, device, expected_arrival_count=1):
16331635

16341636

16351637
# Unit test for arrive/wait
1636-
@pytest.mark.skipif(not (is_hip() or is_hopper_or_newer()), reason="Need Hopper or newer")
1638+
@pytest.mark.skipif(not (is_hip() or is_hopper_or_newer()), reason="Need Hopper or newer or AMD")
16371639
@pytest.mark.parametrize("BLOCK_SIZE", [(1024)])
1638-
# def test_mbarriers(BLOCK_SIZE, device):
16391640
def test_wait_arrive_non_ws(BLOCK_SIZE, device):
16401641
expected_arrival_count = 4 if is_hip() else 1
16411642
kernel = run_tlx_square(tlx_square_non_ws, BLOCK_SIZE, device, expected_arrival_count=expected_arrival_count)
@@ -1652,7 +1653,6 @@ def test_wait_arrive_non_ws(BLOCK_SIZE, device):
16521653

16531654
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")
16541655
@pytest.mark.parametrize("BLOCK_SIZE", [(1024)])
1655-
# def test_mbarriers(BLOCK_SIZE, device):
16561656
def test_wait_arrive_ws(BLOCK_SIZE, device):
16571657
kernel = run_tlx_square(tlx_square_ws, BLOCK_SIZE, device)
16581658

@@ -1699,7 +1699,6 @@ def bar_live_kernel():
16991699

17001700
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")
17011701
@pytest.mark.parametrize("BLOCK_SIZE", [(1024)])
1702-
# def test_mbarriers(BLOCK_SIZE, device):
17031702
def test_named_wait_arrive(BLOCK_SIZE, device):
17041703

17051704
@triton.jit
@@ -2155,10 +2154,7 @@ def ws_error_kernel():
21552154
assert "ZeroDivisionError('division by zero')" in exc_msg, '\n\nExpected ZeroDivisionError but got: \n\n' + exc_msg + '\n\n'
21562155

21572156

2158-
@pytest.mark.skipif(
2159-
not is_cuda() or torch.cuda.get_device_capability()[0] < 9,
2160-
reason="Requires compute capability >= 9 for NV",
2161-
)
2157+
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")
21622158
@pytest.mark.parametrize("BLOCK_SIZE", [(64)])
21632159
def test_local_index(BLOCK_SIZE, device):
21642160

@@ -2199,6 +2195,7 @@ def local_index(
21992195
torch.testing.assert_close(y, output)
22002196

22012197

2198+
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")
22022199
def test_async_token_error(device):
22032200

22042201
@triton.jit

0 commit comments

Comments
 (0)