|
4 | 4 |
|
5 | 5 | import triton
|
6 | 6 | import triton.language as tl
|
7 |
| -from triton._internal_testing import is_blackwell, is_hopper, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy |
| 7 | +from triton._internal_testing import is_hopper, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy |
8 | 8 | from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor
|
9 | 9 | from typing import Optional
|
10 | 10 | from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3
|
@@ -1315,7 +1315,7 @@ def torch_gather_rows(input, idx, y, block_y):
|
1315 | 1315 | @pytest.mark.parametrize("BLOCK_X, BLOCK_Y", [(32, 32), (64, 128), (16, 128), (512, 16)])
|
1316 | 1316 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8])
|
1317 | 1317 | @pytest.mark.parametrize("y", [0, 32, 48])
|
1318 |
| -@pytest.mark.skipif(not is_blackwell(), reason="TMA Gather requires blackwell") |
| 1318 | +@pytest.mark.skipif(is_hopper(), reason="TMA Scatter is not supported on hopper") |
1319 | 1319 | def test_tma_gather(X, Y, BLOCK_X, BLOCK_Y, dtype, y, device):
|
1320 | 1320 | if BLOCK_X > X or y + BLOCK_Y > Y:
|
1321 | 1321 | pytest.skip()
|
@@ -1367,7 +1367,7 @@ def tma_gather_dot_pipeline( #
|
1367 | 1367 | @pytest.mark.interpreter
|
1368 | 1368 | @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(16, 16, 16)])
|
1369 | 1369 | @pytest.mark.parametrize("K", [128])
|
1370 |
| -@pytest.mark.skipif(not is_blackwell(), reason="TMA Gather requires blackwell") |
| 1370 | +@pytest.mark.skipif(is_hopper(), reason="TMA Scatter is not supported on hopper") |
1371 | 1371 | def test_tma_gather_dot_pipeline(BLOCK_M, BLOCK_N, BLOCK_K, K, device):
|
1372 | 1372 |
|
1373 | 1373 | def alloc_fn(size: int, align: int, steam):
|
@@ -1414,7 +1414,7 @@ def tma_scatter_rows_kernel(out_ptr, in_ptr, idx_ptr, y, X: tl.constexpr, Y: tl.
|
1414 | 1414 | @pytest.mark.parametrize("BLOCK_X, BLOCK_Y", [(32, 32), (64, 128), (16, 128), (512, 16)])
|
1415 | 1415 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8])
|
1416 | 1416 | @pytest.mark.parametrize("y", [0, 32, 48])
|
1417 |
| -@pytest.mark.skipif(not is_blackwell(), reason="TMA Scatter requires blackwell") |
| 1417 | +@pytest.mark.skipif(is_hopper(), reason="TMA Scatter is not supported on hopper") |
1418 | 1418 | def test_tma_scatter(X, Y, BLOCK_X, BLOCK_Y, dtype, y):
|
1419 | 1419 | if BLOCK_X > X or y + BLOCK_Y > Y:
|
1420 | 1420 | pytest.skip()
|
|
0 commit comments