Skip to content

Commit c6abb22

Browse files
authored
[Testing] Restore tma gather scatter fallback tests (#7619)
Fixes #7596
1 parent 0ca15d8 commit c6abb22

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

python/test/unit/language/test_tensor_descriptor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import triton
66
import triton.language as tl
7-
from triton._internal_testing import is_blackwell, is_hopper, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy
7+
from triton._internal_testing import is_hopper, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy
88
from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor
99
from typing import Optional
1010
from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3
@@ -1315,7 +1315,7 @@ def torch_gather_rows(input, idx, y, block_y):
13151315
@pytest.mark.parametrize("BLOCK_X, BLOCK_Y", [(32, 32), (64, 128), (16, 128), (512, 16)])
13161316
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8])
13171317
@pytest.mark.parametrize("y", [0, 32, 48])
1318-
@pytest.mark.skipif(not is_blackwell(), reason="TMA Gather requires blackwell")
1318+
@pytest.mark.skipif(is_hopper(), reason="TMA Scatter is not supported on hopper")
13191319
def test_tma_gather(X, Y, BLOCK_X, BLOCK_Y, dtype, y, device):
13201320
if BLOCK_X > X or y + BLOCK_Y > Y:
13211321
pytest.skip()
@@ -1367,7 +1367,7 @@ def tma_gather_dot_pipeline( #
13671367
@pytest.mark.interpreter
13681368
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(16, 16, 16)])
13691369
@pytest.mark.parametrize("K", [128])
1370-
@pytest.mark.skipif(not is_blackwell(), reason="TMA Gather requires blackwell")
1370+
@pytest.mark.skipif(is_hopper(), reason="TMA Scatter is not supported on hopper")
13711371
def test_tma_gather_dot_pipeline(BLOCK_M, BLOCK_N, BLOCK_K, K, device):
13721372

13731373
def alloc_fn(size: int, align: int, steam):
@@ -1414,7 +1414,7 @@ def tma_scatter_rows_kernel(out_ptr, in_ptr, idx_ptr, y, X: tl.constexpr, Y: tl.
14141414
@pytest.mark.parametrize("BLOCK_X, BLOCK_Y", [(32, 32), (64, 128), (16, 128), (512, 16)])
14151415
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8])
14161416
@pytest.mark.parametrize("y", [0, 32, 48])
1417-
@pytest.mark.skipif(not is_blackwell(), reason="TMA Scatter requires blackwell")
1417+
@pytest.mark.skipif(is_hopper(), reason="TMA Scatter is not supported on hopper")
14181418
def test_tma_scatter(X, Y, BLOCK_X, BLOCK_Y, dtype, y):
14191419
if BLOCK_X > X or y + BLOCK_Y > Y:
14201420
pytest.skip()

0 commit comments

Comments
 (0)