55import re
66import triton
77import triton .language as tl
8- from triton ._internal_testing import is_hopper_or_newer , is_blackwell , is_hopper , is_hip , is_cuda
8+ from triton ._internal_testing import is_hopper_or_newer , is_blackwell , is_hopper , is_hip
99import triton .language .extra .tlx as tlx
1010from typing import Optional
1111import traceback
@@ -585,6 +585,7 @@ def test_cta_0_kernel(
585585 torch .testing .assert_close (output , expected_output )
586586
587587
588+ @pytest .mark .skipif (not is_hopper_or_newer (), reason = "Need Hopper or newer" )
588589def test_clock64 (device ):
589590
590591 @triton .jit
@@ -1245,6 +1246,7 @@ def tcgen5_dot_kernel2cta_tma_ws(a_ptr, stride_am, stride_ak, b_ptr, stride_bk,
12451246
12461247@pytest .mark .parametrize ("A_DATA_TYPE" , ["e5m2" , "e4m3" ])
12471248@pytest .mark .parametrize ("B_DATA_TYPE" , ["e5m2" , "e4m3" ])
1249+ @pytest .mark .skipif (not is_blackwell (), reason = "Need Blackwell" )
12481250def test_async_dot_scaled (A_DATA_TYPE , B_DATA_TYPE , device ):
12491251 """
12501252 Test D = (A * A_scale) * (B * B_scale)
@@ -1633,9 +1635,8 @@ def run_tlx_square(func, BLOCK_SIZE, device, expected_arrival_count=1):
16331635
16341636
16351637# Unit test for arrive/wait
1636- @pytest .mark .skipif (not (is_hip () or is_hopper_or_newer ()), reason = "Need Hopper or newer" )
1638+ @pytest .mark .skipif (not (is_hip () or is_hopper_or_newer ()), reason = "Need Hopper or newer or AMD " )
16371639@pytest .mark .parametrize ("BLOCK_SIZE" , [(1024 )])
1638- # def test_mbarriers(BLOCK_SIZE, device):
16391640def test_wait_arrive_non_ws (BLOCK_SIZE , device ):
16401641 expected_arrival_count = 4 if is_hip () else 1
16411642 kernel = run_tlx_square (tlx_square_non_ws , BLOCK_SIZE , device , expected_arrival_count = expected_arrival_count )
@@ -1652,7 +1653,6 @@ def test_wait_arrive_non_ws(BLOCK_SIZE, device):
16521653
16531654@pytest .mark .skipif (not is_hopper_or_newer (), reason = "Need Hopper or newer" )
16541655@pytest .mark .parametrize ("BLOCK_SIZE" , [(1024 )])
1655- # def test_mbarriers(BLOCK_SIZE, device):
16561656def test_wait_arrive_ws (BLOCK_SIZE , device ):
16571657 kernel = run_tlx_square (tlx_square_ws , BLOCK_SIZE , device )
16581658
@@ -1699,7 +1699,6 @@ def bar_live_kernel():
16991699
17001700@pytest .mark .skipif (not is_hopper_or_newer (), reason = "Need Hopper or newer" )
17011701@pytest .mark .parametrize ("BLOCK_SIZE" , [(1024 )])
1702- # def test_mbarriers(BLOCK_SIZE, device):
17031702def test_named_wait_arrive (BLOCK_SIZE , device ):
17041703
17051704 @triton .jit
@@ -2155,10 +2154,7 @@ def ws_error_kernel():
21552154 assert "ZeroDivisionError('division by zero')" in exc_msg , '\n \n Expected ZeroDivisionError but got: \n \n ' + exc_msg + '\n \n '
21562155
21572156
2158- @pytest .mark .skipif (
2159- not is_cuda () or torch .cuda .get_device_capability ()[0 ] < 9 ,
2160- reason = "Requires compute capability >= 9 for NV" ,
2161- )
2157+ @pytest .mark .skipif (not is_hopper_or_newer (), reason = "Need Hopper or newer" )
21622158@pytest .mark .parametrize ("BLOCK_SIZE" , [(64 )])
21632159def test_local_index (BLOCK_SIZE , device ):
21642160
@@ -2199,6 +2195,7 @@ def local_index(
21992195 torch .testing .assert_close (y , output )
22002196
22012197
2198+ @pytest .mark .skipif (not is_hopper_or_newer (), reason = "Need Hopper or newer" )
22022199def test_async_token_error (device ):
22032200
22042201 @triton .jit
0 commit comments