
Commit e74db2d

[Gluon] Implement async TMA ops (except reduce) (#7004)

1 parent: e5aa2ab

File tree: 6 files changed, +227 −2 lines


python/src/gluon_ir.cc

Lines changed: 39 additions & 0 deletions
@@ -264,6 +264,45 @@ void init_gluon_ir(py::module &&m) {
                 pred, two_ctas, mbarriers,
                 mbarrier_preds);
           })
+      .def("create_tensor_desc_to_tma_ptr",
+           [](GluonOpBuilder &self, Value desc) -> Value {
+             return self.create<ttng::TensorDescToTMAPtrOp>(desc);
+           })
+      .def("create_async_tma_copy_global_to_local",
+           [](GluonOpBuilder &self, Value descPtr, std::vector<Value> &coord,
+              Value barrier, Value result, Value pred) {
+             self.create<ttng::AsyncTMACopyGlobalToLocalOp>(
+                 descPtr, coord, barrier, result, pred);
+           })
+      .def("create_async_tma_copy_local_to_global",
+           [](GluonOpBuilder &self, Value descPtr, std::vector<Value> &coord,
+              Value src) {
+             self.create<ttng::AsyncTMACopyLocalToGlobalOp>(descPtr, coord,
+                                                            src);
+           })
+      .def("create_async_tma_reduce",
+           [](GluonOpBuilder &self, triton::DescriptorReduceKind kind,
+              Value descPtr, std::vector<Value> &coord, Value src) {
+             self.create<ttng::AsyncTMAReduceOp>(kind, descPtr, coord, src);
+           })
+      .def("create_async_tma_store_wait",
+           [](GluonOpBuilder &self, int pendings) {
+             self.create<ttng::TMAStoreWaitOp>(pendings);
+           })
+      .def("create_async_tma_gather",
+           [](GluonOpBuilder &self, Value descPtr, Value xOffsets,
+              Value yOffset, Value barrier, Value result, Value pred) {
+             self.create<ttng::AsyncTMAGatherOp>(descPtr, xOffsets, yOffset,
+                                                 barrier, result, pred);
+           })
+      .def("create_async_tma_scatter",
+           [](GluonOpBuilder &self, Value descPtr, Value xOffsets,
+              Value yOffset, Value src) {
+             self.create<ttng::AsyncTMAScatterOp>(descPtr, xOffsets, yOffset,
+                                                  src);
+           })
       .def("create_warp_return",
            [](GluonOpBuilder &self) -> Operation * {
              return self.create<ttg::WarpReturnOp>();
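
Note: user kernels do not call these GluonOpBuilder bindings directly; the Gluon frontend builtins added later in this commit drive them. A minimal sketch of the call pattern, assuming `_builder` is the GluonOpBuilder handed to a @builtin and that `tensor_desc`, `coord`, `barrier`, `result`, and `pred` are the usual Gluon value wrappers:

    # Sketch only, mirroring the frontend code later in this commit: lower the
    # descriptor to a raw TMA pointer, then issue the async copy against it.
    tma_ptr = _builder.create_tensor_desc_to_tma_ptr(tensor_desc.handle)
    _builder.create_async_tma_copy_global_to_local(tma_ptr, coord, barrier.handle,
                                                   result.handle, pred.handle)
    # Stores complete through a pending-store count rather than an mbarrier:
    _builder.create_async_tma_store_wait(0)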

python/test/gluon/test_frontend.py

Lines changed: 122 additions & 1 deletion
@@ -6,10 +6,11 @@
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
 from triton.experimental.gluon.language.nvidia import blackwell
-from triton.experimental.gluon.language.nvidia.blackwell import mbarrier
+from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma
 from triton._filecheck import filecheck_test, run_parser
 import triton.language as tl
 from triton._internal_testing import is_cuda
+from triton.tools.tensor_descriptor import TensorDescriptor
 from triton.compiler.errors import CompilationError
 
 
@@ -408,6 +409,126 @@ def test_tcgen05_mma(fresh_knobs):
     """)
 
 
+@gluon.jit
+def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr, smem_layout: ttgl.constexpr):
+    smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], smem_layout)
+    bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
+    mbarrier.init(bar, count=1)
+
+    tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem)
+    mbarrier.expect(bar, XBLOCK * XBLOCK * ttgl.float16.primitive_bitwidth // 8)
+    mbarrier.wait(bar, 0)
+
+    mbarrier.invalidate(bar)
+
+    tma.async_copy_shared_to_global(input_desc, [0, 0], smem)
+    tma.store_wait(0)
+
+
+@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="TMA requires at least Hopper")
+def test_async_tma(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    input = torch.randn((1024, 1024), device="cuda", dtype=torch.float16)
+    XBLOCK = 128
+    input_desc = TensorDescriptor.from_tensor(input, [XBLOCK, XBLOCK])
+    shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
+
+    h = async_tma_kernel.warmup(input_desc, XBLOCK, shared_layout, grid=(1, ), num_warps=4)
+    expecttest.assert_expected_inline(
+        h.asm["source"], """\
+#loc = loc(unknown)
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-warps" = 4 : i32} {
+  tt.func public @async_tma_kernel(%arg0: !tt.tensordesc<tensor<128x128xf16>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    ttng.init_barrier %1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    %c0_i32 = arith.constant 0 : i32 loc(#loc)
+    %c0_i32_0 = arith.constant 0 : i32 loc(#loc)
+    %true = arith.constant true loc(#loc)
+    %2 = ttng.tensor_desc_to_tma_ptr %arg0 : !tt.tensordesc<tensor<128x128xf16>> to !tt.ptr<i8> loc(#loc)
+    ttng.async_tma_copy_global_to_local %2[%c0_i32, %c0_i32_0] %0, %1, %true : !tt.ptr<i8>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    %true_1 = arith.constant true loc(#loc)
+    ttng.barrier_expect %1, 32768, %true_1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    %c0_i32_2 = arith.constant 0 : i32 loc(#loc)
+    %true_3 = arith.constant true loc(#loc)
+    ttng.wait_barrier %1, %c0_i32_2, %true_3 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    ttng.inval_barrier %1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    %c0_i32_4 = arith.constant 0 : i32 loc(#loc)
+    %c0_i32_5 = arith.constant 0 : i32 loc(#loc)
+    %3 = ttng.tensor_desc_to_tma_ptr %arg0 : !tt.tensordesc<tensor<128x128xf16>> to !tt.ptr<i8> loc(#loc)
+    ttng.async_tma_copy_local_to_global %3[%c0_i32_4, %c0_i32_5] %0 : !tt.ptr<i8>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    ttng.async_tma_store_wait {pendings = 0 : i32} loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+""")
+
+
+@gluon.jit
+def async_tma_blackwell_kernel(input_desc, XBLOCK: ttgl.constexpr, smem_layout: ttgl.constexpr):
+    smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], smem_layout)
+    bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
+    mbarrier.init(bar, count=1)
+
+    offset_layout: tl.constexpr = ttgl.BlockedLayout([1, 4], [32, 1], [1, 4], [1, 0])
+    x_offsets = ttgl.arange(0, XBLOCK, layout=ttgl.SliceLayout(0, offset_layout))
+    tma.async_gather(input_desc, x_offsets, 0, bar, smem)
+    mbarrier.expect(bar, XBLOCK * XBLOCK * ttgl.float16.primitive_bitwidth // 8)
+    mbarrier.wait(bar, 0)
+
+    mbarrier.invalidate(bar)
+
+    tma.async_scatter(input_desc, x_offsets, 0, smem)
+    tma.store_wait(0)
+
+
+@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 10, reason="Requires Blackwell")
+def test_async_tma_blackwell(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    input = torch.randn((1024, 1024), device="cuda", dtype=torch.float16)
+    XBLOCK = 128
+    input_desc = TensorDescriptor.from_tensor(input, [1, XBLOCK])
+    shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
+
+    h = async_tma_blackwell_kernel.warmup(input_desc, XBLOCK, shared_layout, grid=(1, ), num_warps=4)
+    expecttest.assert_expected_inline(
+        h.asm["source"], """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
+#loc = loc(unknown)
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-warps" = 4 : i32} {
+  tt.func public @async_tma_blackwell_kernel(%arg0: !tt.tensordesc<tensor<1x128xf16>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    ttng.init_barrier %1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %true = arith.constant true loc(#loc)
+    %c0_i32 = arith.constant 0 : i32 loc(#loc)
+    %3 = ttng.tensor_desc_to_tma_ptr %arg0 : !tt.tensordesc<tensor<1x128xf16>> to !tt.ptr<i8> loc(#loc)
+    ttng.async_tma_gather %3[%2, %c0_i32] %0, %1, %true : !tt.ptr<i8>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, i1 loc(#loc)
+    %true_0 = arith.constant true loc(#loc)
+    ttng.barrier_expect %1, 32768, %true_0 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    %c0_i32_1 = arith.constant 0 : i32 loc(#loc)
+    %true_2 = arith.constant true loc(#loc)
+    ttng.wait_barrier %1, %c0_i32_1, %true_2 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    ttng.inval_barrier %1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    %4 = ttng.tensor_desc_to_tma_ptr %arg0 : !tt.tensordesc<tensor<1x128xf16>> to !tt.ptr<i8> loc(#loc)
+    %c0_i32_3 = arith.constant 0 : i32 loc(#loc)
+    ttng.async_tma_scatter %4[%2, %c0_i32_3] %0 : !tt.ptr<i8>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    ttng.async_tma_store_wait {pendings = 0 : i32} loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+""")
+
+
 def test_mlir_attr_error():
 
 @gluon.jit
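
For orientation, the tests above only inspect the generated IR via warmup. A hypothetical sketch of launching async_tma_kernel for real (the launch syntax and values are assumptions mirroring the warmup call; since the kernel round-trips one 128x128 tile through shared memory, the tensor should come back unchanged):

    # Hypothetical runtime launch, assuming a Hopper-class GPU.
    import torch
    from triton.tools.tensor_descriptor import TensorDescriptor
    from triton.experimental.gluon import language as ttgl

    inp = torch.randn((1024, 1024), device="cuda", dtype=torch.float16)
    XBLOCK = 128
    desc = TensorDescriptor.from_tensor(inp, [XBLOCK, XBLOCK])
    layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
    # One-element grid, as in the warmup call above.
    async_tma_kernel[(1, )](desc, XBLOCK, layout, num_warps=4)
    torch.cuda.synchronize()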

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 from triton.experimental.gluon.language import _core as ttgl
 from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr
 
+from . import tma
 from ..hopper import mbarrier
 
 if TYPE_CHECKING:
@@ -16,6 +17,7 @@
     "tensor_memory_descriptor",
     "allocate_tensor_memory",
     "mbarrier",
+    "tma",
 ]
 
 
python/triton/experimental/gluon/language/nvidia/blackwell/tma.py

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+from triton.experimental.gluon.language._core import builtin
+import triton.experimental.gluon.language._core as ttgl
+from triton.experimental.gluon.language.nvidia.hopper.tma import (
+    _tensor_desc_to_tma_ptr,
+    async_copy_global_to_shared,
+    async_copy_shared_to_global,
+    store_wait,
+)
+
+__all__ = [
+    "async_gather",
+    "async_scatter",
+    "async_copy_global_to_shared",
+    "async_copy_shared_to_global",
+    "store_wait",
+]
+
+
+@builtin
+def async_gather(tensor_desc, x_offsets, y_offset, barrier, result, pred=True, _builder=None):
+    pred = ttgl.to_tensor(pred, _builder=_builder)
+    y_offset = ttgl.to_tensor(y_offset, _builder=_builder)
+    tma_ptr = _tensor_desc_to_tma_ptr(tensor_desc, _builder)
+    _builder.create_async_tma_gather(tma_ptr, x_offsets.handle, y_offset.handle, barrier.handle, result.handle,
+                                     pred.handle)
+
+
+@builtin
+def async_scatter(tensor_desc, x_offsets, y_offset, src, _builder=None):
+    tma_ptr = _tensor_desc_to_tma_ptr(tensor_desc, _builder)
+    y_offset = ttgl.to_tensor(y_offset, _builder=_builder)
+    _builder.create_async_tma_scatter(tma_ptr, x_offsets.handle, y_offset.handle, src.handle)
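
Read together with the Blackwell test above: `x_offsets` is a 1-D tensor of row indices into the descriptor and `y_offset` a scalar column start, so gathering 128 offsets from a descriptor with block shape [1, XBLOCK] fills a full [XBLOCK, XBLOCK] shared tile. A sketch of the pairing, reusing the names from async_tma_blackwell_kernel:

    # Gather 128 descriptor rows selected by x_offsets into smem, then wait on
    # the mbarrier the copy was bound to (byte count = tile size in bytes).
    tma.async_gather(input_desc, x_offsets, 0, bar, smem)
    mbarrier.expect(bar, XBLOCK * XBLOCK * ttgl.float16.primitive_bitwidth // 8)
    mbarrier.wait(bar, 0)
    # Scatter the tile back out; stores are drained with a pending count.
    tma.async_scatter(input_desc, x_offsets, 0, smem)
    tma.store_wait(0)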
python/triton/experimental/gluon/language/nvidia/hopper/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,3 +1,4 @@
 from . import mbarrier
+from . import tma
 
-__all__ = ["mbarrier"]
+__all__ = ["mbarrier", "tma"]
python/triton/experimental/gluon/language/nvidia/hopper/tma.py

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+from triton.language.semantic import _convert_to_ir_values
+from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
+import triton.experimental.gluon.language._core as ttgl
+
+__all__ = ["async_copy_global_to_shared", "async_copy_shared_to_global", "store_wait"]
+
+
+def _tensor_desc_to_tma_ptr(tensor_desc, builder):
+    return builder.create_tensor_desc_to_tma_ptr(tensor_desc.handle)
+
+
+@builtin
+def async_copy_global_to_shared(tensor_desc, coord, barrier, result, pred=True, _builder=None):
+    coord = _convert_to_ir_values(_builder, coord, require_i64=False)
+    pred = ttgl.to_tensor(pred, _builder=_builder)
+    tma_ptr = _tensor_desc_to_tma_ptr(tensor_desc, _builder)
+    _builder.create_async_tma_copy_global_to_local(tma_ptr, coord, barrier.handle, result.handle, pred.handle)
+
+
+@builtin
+def async_copy_shared_to_global(tensor_desc, coord, src, _builder=None):
+    coord = _convert_to_ir_values(_builder, coord, require_i64=False)
+    tma_ptr = _tensor_desc_to_tma_ptr(tensor_desc, _builder)
+    _builder.create_async_tma_copy_local_to_global(tma_ptr, coord, src.handle)
+
+
+@builtin
+def store_wait(pendings, _builder=None):
+    pendings = _unwrap_if_constexpr(pendings)
+    _builder.create_async_tma_store_wait(pendings)
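
Note that async_copy_global_to_shared does not synchronize by itself; the caller pairs it with an mbarrier, as the Hopper test kernel above does. A minimal in-kernel sketch, assuming `input_desc`, `smem`, and `XBLOCK` as in that kernel:

    # Arrival protocol for the async load: init a barrier, issue the copy bound
    # to it, declare the expected byte count, then block until it completes.
    bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(bar, count=1)
    tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem)
    mbarrier.expect(bar, XBLOCK * XBLOCK * ttgl.float16.primitive_bitwidth // 8)
    mbarrier.wait(bar, 0)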
