|
4 | 4 |
|
5 | 5 | import triton |
6 | 6 | import triton.language as tl |
7 | | -from triton._internal_testing import is_blackwell, is_hopper, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy |
| 7 | +from triton._internal_testing import is_hopper, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy |
8 | 8 | from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor |
9 | 9 | from typing import Optional |
10 | 10 | from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3 |
@@ -1315,7 +1315,7 @@ def torch_gather_rows(input, idx, y, block_y): |
1315 | 1315 | @pytest.mark.parametrize("BLOCK_X, BLOCK_Y", [(32, 32), (64, 128), (16, 128), (512, 16)]) |
1316 | 1316 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8]) |
1317 | 1317 | @pytest.mark.parametrize("y", [0, 32, 48]) |
1318 | | -@pytest.mark.skipif(not is_blackwell(), reason="TMA Gather requires blackwell") |
| 1318 | +@pytest.mark.skipif(is_hopper(), reason="TMA Scatter is not supported on hopper") |
1319 | 1319 | def test_tma_gather(X, Y, BLOCK_X, BLOCK_Y, dtype, y, device): |
1320 | 1320 | if BLOCK_X > X or y + BLOCK_Y > Y: |
1321 | 1321 | pytest.skip() |
@@ -1367,7 +1367,7 @@ def tma_gather_dot_pipeline( # |
1367 | 1367 | @pytest.mark.interpreter |
1368 | 1368 | @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(16, 16, 16)]) |
1369 | 1369 | @pytest.mark.parametrize("K", [128]) |
1370 | | -@pytest.mark.skipif(not is_blackwell(), reason="TMA Gather requires blackwell") |
| 1370 | +@pytest.mark.skipif(is_hopper(), reason="TMA Scatter is not supported on hopper") |
1371 | 1371 | def test_tma_gather_dot_pipeline(BLOCK_M, BLOCK_N, BLOCK_K, K, device): |
1372 | 1372 |
|
1373 | 1373 | def alloc_fn(size: int, align: int, steam): |
@@ -1414,7 +1414,7 @@ def tma_scatter_rows_kernel(out_ptr, in_ptr, idx_ptr, y, X: tl.constexpr, Y: tl. |
1414 | 1414 | @pytest.mark.parametrize("BLOCK_X, BLOCK_Y", [(32, 32), (64, 128), (16, 128), (512, 16)]) |
1415 | 1415 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8]) |
1416 | 1416 | @pytest.mark.parametrize("y", [0, 32, 48]) |
1417 | | -@pytest.mark.skipif(not is_blackwell(), reason="TMA Scatter requires blackwell") |
| 1417 | +@pytest.mark.skipif(is_hopper(), reason="TMA Scatter is not supported on hopper") |
1418 | 1418 | def test_tma_scatter(X, Y, BLOCK_X, BLOCK_Y, dtype, y): |
1419 | 1419 | if BLOCK_X > X or y + BLOCK_Y > Y: |
1420 | 1420 | pytest.skip() |
|
0 commit comments