Commit 4d85824

[NVIDIA] Enable TMA gather4 on sm_120 and sm_121 (#8498)
- Enable cp.async.bulk.tensor.2d.tile::gather4.shared on sm_120 and sm_121.
- Skip TMA scatter4 test on sm_120 since it is unsupported by hardware.

Note: All other TMA features except for cluster-related ones are supported on sm_120.
1 parent 4734af3 commit 4d85824
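
For context, a hedged sketch (not part of this commit) of how the support split described above could be queried from Python. It assumes PyTorch is installed and encodes only what the commit message states: gather4 is enabled while scatter4 and cluster-related TMA features stay unavailable on sm_120/sm_121. The helper name is made up for illustration:

import torch

def sm12x_tma_support():
    # Illustrative only: summarizes what the commit message says about sm_120/sm_121.
    if not torch.cuda.is_available():
        return None
    major, _minor = torch.cuda.get_device_capability()
    if major != 12:  # sm_120 and sm_121 both report compute capability 12.x
        return None
    return {
        "tma_gather4": True,            # enabled by this commit
        "tma_scatter4": False,          # unsupported by the hardware
        "tma_cluster_features": False,  # cluster-related TMA features unavailable
    }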

File tree

4 files changed (+8, -3 lines)


python/test/unit/language/test_tensor_descriptor.py

Lines changed: 2 additions & 1 deletion
@@ -4,7 +4,7 @@
 
 import triton
 import triton.language as tl
-from triton._internal_testing import is_hopper, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy
+from triton._internal_testing import is_hopper, is_sm12x, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy
 from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor
 from typing import Optional
 from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3
@@ -1474,6 +1474,7 @@ def tma_scatter_rows_kernel(out_ptr, in_ptr, idx_ptr, y, X: tl.constexpr, Y: tl.
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8])
 @pytest.mark.parametrize("y", [0, 32, 48])
 @pytest.mark.skipif(is_hopper(), reason="TMA Scatter is not supported on hopper")
+@pytest.mark.skipif(is_sm12x(), reason="TMA Scatter is not supported on sm120")
 def test_tma_scatter(X, Y, BLOCK_X, BLOCK_Y, dtype, y, device):
     if BLOCK_X > X or y + BLOCK_Y > Y:
         pytest.skip()

python/triton/_internal_testing.py

Lines changed: 4 additions & 0 deletions
@@ -54,6 +54,10 @@ def is_hopper():
     return is_cuda() and torch.cuda.get_device_capability()[0] == 9
 
 
+def is_sm12x():
+    return is_cuda() and torch.cuda.get_device_capability()[0] == 12
+
+
 def is_hip():
     target = get_current_target()
     return False if target is None else target.backend == "hip"
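
For reference, a hypothetical usage sketch of the new helper (the test name below is made up; only the imports and markers mirror this commit). It shows how a TMA-scatter test would be skipped on both Hopper and sm_12x devices, matching the marker added in test_tensor_descriptor.py above:

import pytest
from triton._internal_testing import is_hopper, is_sm12x

@pytest.mark.skipif(is_hopper(), reason="TMA Scatter is not supported on hopper")
@pytest.mark.skipif(is_sm12x(), reason="TMA Scatter is not supported on sm120")
def test_hypothetical_scatter_path():
    ...  # any body relying on TMA scatter4 would not run on these devices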

test/Conversion/tma_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ tt.func @tma_gather_simple(%arg0: !tt.tensordesc<tensor<1x128xbf16, #shared1>>,
 
 // CHECK: [[OFFSET0:%.*]] = zext nneg i32 [[WARP_STRIDE]] to i64
 // CHECK: [[BASEPTR0:%.*]] = getelementptr bfloat, ptr addrspace(3) [[BASE_PTR]], i64 [[OFFSET0]]
-// CHECK: "@$0 cp.async.bulk.tensor.2d.tile::gather4.shared::cluster.global.mbarrier::complete_tx::bytes [$1], [$2, {$3, $4, $5, $6, $7}], [$8];", "b,r,l,r,r,r,r,r,r"
+// CHECK: "@$0 cp.async.bulk.tensor.2d.tile::gather4.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$2, {$3, $4, $5, $6, $7}], [$8];", "b,r,l,r,r,r,r,r,r"
 // CHECK-SAME: (i1 [[PRED]], ptr addrspace(3) [[BASEPTR0]], ptr nonnull %0, i32 [[Y0]], i32 [[IDX0]], i32 [[IDX1]], i32 [[IDX2]], i32 [[IDX3]], ptr addrspace(3) [[BAR]])
 
 // CHECK: [[BASEPTR1:%.*]] = getelementptr i8, ptr addrspace(3) [[BASEPTR0]], i64 4096

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -1717,7 +1717,7 @@ LogicalResult AsyncTMAGatherOpConversion::matchAndRewrite(
   auto callback = [&](Value pred, Value shMemPtr, Value yOffset,
                       ArrayRef<Value> xOffsets) {
     std::string tmaInst = "@$0 cp.async.bulk.tensor.2d.tile::gather4.shared"
-                          "::cluster.global.mbarrier::complete_tx::bytes "
+                          "::cta.global.mbarrier::complete_tx::bytes "
                           "[$1], [$2, {$3, $4, $5, $6, $7}], [$8];";
 
     PTXBuilder ptxBuilder;
