
Commit d0c65f9

[Gluon] Implement TensorDescriptor kernel arguments (#7142)
1 parent 3597ff1 commit d0c65f9

15 files changed, +207 -25 lines

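Taken together, the change lets a host-side TensorDescriptor carry an explicit NVMMASharedLayout and be passed straight into a @gluon.jit kernel. A minimal sketch of the resulting user-facing flow, condensed from the new test in test_core.py below (the kernel name is illustrative, and the NVMMASharedLayout call relies on the defaults exercised in test_frontend.py):

    import torch
    from triton.experimental import gluon
    from triton.experimental.gluon import language as ttgl
    from triton.experimental.gluon.language.nvidia.hopper import tma


    @gluon.jit
    def zero_tile_kernel(desc):  # illustrative name; mirrors tma_kernel in test_core.py
        # Block shape, element type, and shared-memory layout all travel with
        # the descriptor argument instead of being passed as separate constexprs.
        reg_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0])
        value = ttgl.full(desc.block_shape, 0, desc.dtype, reg_layout)
        smem = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout, value)
        tma.async_copy_shared_to_global(desc, [0, 0], smem)
        tma.store_wait(0)
        smem._keep_alive()  # keep the allocation alive until the TMA store completes


    # Host side (requires a Hopper-class GPU):
    out = torch.ones((16, 16), dtype=torch.float16, device="cuda")
    layout = ttgl.NVMMASharedLayout(swizzle_byte_width=32, element_bitwidth=16, rank=2)
    desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(out, [16, 16], layout)
    zero_tile_kernel[(1, )](desc)  # out becomes all zeros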

python/src/gluon_ir.cc

Lines changed: 9 additions & 0 deletions
@@ -270,6 +270,15 @@ void init_gluon_ir(py::module &&m) {
              assert(ty.getEncoding());
              return layoutToGluon(ty.getEncoding());
            })
+      .def("get_tensor_descriptor_layout_type",
+           [](GluonOpBuilder &self, Type blockType, bool isSigned,
+              Attribute layout) -> Type {
+             auto ctx = self.getContext();
+             auto blockTy = cast<RankedTensorType>(blockType);
+             auto blockTyLayout = RankedTensorType::get(
+                 blockTy.getShape(), blockTy.getElementType(), layout);
+             return triton::TensorDescType::get(ctx, blockTyLayout, isSigned);
+           })
       .def("create_convert_layout",
            [](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
              return self.create<ttg::ConvertLayoutOp>(resultTy, value);
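This binding is the C++ half of the feature: it rebuilds the block's RankedTensorType with the supplied Gluon layout attribute and wraps it in a TensorDescType, so descriptor kernel arguments lower to a layout-annotated !tt.tensordesc<tensor<..., #shared>> instead of the previous bare !tt.tensordesc<tensor<...>>. The updated expected IR in test_frontend.py below shows exactly this change in the function signatures.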

python/test/gluon/test_core.py

Lines changed: 28 additions & 0 deletions
@@ -1,8 +1,10 @@
 import torch
 import pytest
 
+from triton._internal_testing import is_cuda
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
+from triton.experimental.gluon.language.nvidia.hopper import tma
 
 
 @gluon.jit
@@ -31,3 +33,29 @@ def test_copy_kernel(layout, XBLOCK):
 
     copy_kernel[(4, )](out, inp, inp.numel(), XBLOCK, layout, num_warps=layout.warps_per_cta[0])
     torch.testing.assert_close(out, inp)
+
+
+@gluon.jit
+def tma_kernel(desc):
+    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0])
+    value = ttgl.full(desc.block_shape, 0, desc.dtype, layout)
+    alloc = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout, value)
+    tma.async_copy_shared_to_global(desc, [0, 0], alloc)
+    tma.store_wait(0)
+    alloc._keep_alive()
+
+
+@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires Hopper")
+def test_tma():
+    out = torch.ones((16, 16), dtype=torch.float16, device="cuda")
+    layout = ttgl.NVMMASharedLayout(
+        swizzle_byte_width=32,
+        element_bitwidth=16,
+        rank=2,
+        transposed=False,
+        fp4_padded=False,
+    )
+
+    desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(out, [16, 16], layout)
+    tma_kernel[(1, )](desc)
+    torch.testing.assert_close(out, torch.zeros_like(out))
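The new test round-trips the descriptor through the kernel ABI: the kernel builds a zero tile in registers, stages it in shared memory using the descriptor's own layout, and TMA-stores it over a tensor of ones, so the final assert only passes if the block shape, dtype, and layout all arrive intact on the device side.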

python/test/gluon/test_frontend.py

Lines changed: 15 additions & 15 deletions
@@ -9,10 +9,10 @@
 from triton.experimental.gluon import language as ttgl
 from triton.experimental.gluon.language.nvidia import blackwell
 from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma, TensorMemoryLayout
+from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
 from triton._filecheck import filecheck_test, run_parser
 import triton.language as tl
 from triton._internal_testing import is_cuda
-from triton.tools.tensor_descriptor import TensorDescriptor
 from triton.compiler.errors import CompilationError
 
 TARGET_PAT = re.compile('ttg.target = "[^"]*"')
@@ -434,8 +434,8 @@ def test_tcgen05_mma(fresh_knobs):
 
 
 @gluon.jit
-def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr, smem_layout: ttgl.constexpr):
-    smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], smem_layout)
+def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr):
+    smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)
     bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
     mbarrier.init(bar, count=1)
 
@@ -455,25 +455,25 @@ def test_async_tma(fresh_knobs):
 
     input = torch.randn((1024, 1024), device="cuda", dtype=torch.float16)
     XBLOCK = 128
-    input_desc = TensorDescriptor.from_tensor(input, [XBLOCK, XBLOCK])
     shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
+    input_desc = TensorDescriptor.from_tensor(input, [XBLOCK, XBLOCK], shared_layout)
 
-    h = async_tma_kernel.warmup(input_desc, XBLOCK, shared_layout, grid=(1, ), num_warps=4)
+    h = async_tma_kernel.warmup(input_desc, XBLOCK, grid=(1, ), num_warps=4)
     expecttest.assert_expected_inline(
         anonymize_ir(h.asm["source"]), """\
 #loc = loc(unknown)
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @async_tma_kernel(%arg0: !tt.tensordesc<tensor<128x128xf16>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
+  tt.func public @async_tma_kernel(%arg0: !tt.tensordesc<tensor<128x128xf16, #shared>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
     %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     ttng.init_barrier %1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     %c0_i32 = arith.constant 0 : i32 loc(#loc)
     %c0_i32_0 = arith.constant 0 : i32 loc(#loc)
     %true = arith.constant true loc(#loc)
-    ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32_0] %0, %1, %true : !tt.tensordesc<tensor<128x128xf16>>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32_0] %0, %1, %true : !tt.tensordesc<tensor<128x128xf16, #shared>>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
     %true_1 = arith.constant true loc(#loc)
     ttng.barrier_expect %1, 32768, %true_1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     %c0_i32_2 = arith.constant 0 : i32 loc(#loc)
@@ -482,7 +482,7 @@ def test_async_tma(fresh_knobs):
     ttng.inval_barrier %1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     %c0_i32_4 = arith.constant 0 : i32 loc(#loc)
     %c0_i32_5 = arith.constant 0 : i32 loc(#loc)
-    ttng.async_tma_copy_local_to_global %arg0[%c0_i32_4, %c0_i32_5] %0 : !tt.tensordesc<tensor<128x128xf16>>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    ttng.async_tma_copy_local_to_global %arg0[%c0_i32_4, %c0_i32_5] %0 : !tt.tensordesc<tensor<128x128xf16, #shared>>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
     ttng.async_tma_store_wait {pendings = 0 : i32} loc(#loc)
     tt.return loc(#loc)
   } loc(#loc)
@@ -491,8 +491,8 @@
 
 
 @gluon.jit
-def async_tma_blackwell_kernel(input_desc, XBLOCK: ttgl.constexpr, smem_layout: ttgl.constexpr):
-    smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], smem_layout)
+def async_tma_blackwell_kernel(input_desc, XBLOCK: ttgl.constexpr):
+    smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)
     bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
     mbarrier.init(bar, count=1)
 
@@ -514,10 +514,10 @@ def test_async_tma_blackwell(fresh_knobs):
 
     input = torch.randn((1024, 1024), device="cuda", dtype=torch.float16)
     XBLOCK = 128
-    input_desc = TensorDescriptor.from_tensor(input, [1, XBLOCK])
     shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
+    input_desc = TensorDescriptor.from_tensor(input, [1, XBLOCK], shared_layout)
 
-    h = async_tma_blackwell_kernel.warmup(input_desc, XBLOCK, shared_layout, grid=(1, ), num_warps=4)
+    h = async_tma_blackwell_kernel.warmup(input_desc, XBLOCK, grid=(1, ), num_warps=4)
     expecttest.assert_expected_inline(
         anonymize_ir(h.asm["source"]), """\
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
@@ -526,22 +526,22 @@ def test_async_tma_blackwell(fresh_knobs):
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @async_tma_blackwell_kernel(%arg0: !tt.tensordesc<tensor<1x128xf16>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
+  tt.func public @async_tma_blackwell_kernel(%arg0: !tt.tensordesc<tensor<1x128xf16, #shared>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
    %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
    ttng.init_barrier %1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
    %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
    %true = arith.constant true loc(#loc)
    %c0_i32 = arith.constant 0 : i32 loc(#loc)
-    ttng.async_tma_gather %arg0[%2, %c0_i32] %0, %1, %true : !tt.tensordesc<tensor<1x128xf16>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, i1 loc(#loc)
+    ttng.async_tma_gather %arg0[%2, %c0_i32] %0, %1, %true : !tt.tensordesc<tensor<1x128xf16, #shared>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, i1 loc(#loc)
    %true_0 = arith.constant true loc(#loc)
    ttng.barrier_expect %1, 32768, %true_0 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
    %c0_i32_1 = arith.constant 0 : i32 loc(#loc)
    %true_2 = arith.constant true loc(#loc)
    ttng.wait_barrier %1, %c0_i32_1, %true_2 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
    ttng.inval_barrier %1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
    %c0_i32_3 = arith.constant 0 : i32 loc(#loc)
-    ttng.async_tma_scatter %arg0[%2, %c0_i32_3] %0 : !tt.tensordesc<tensor<1x128xf16>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    ttng.async_tma_scatter %arg0[%2, %c0_i32_3] %0 : !tt.tensordesc<tensor<1x128xf16, #shared>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
    ttng.async_tma_store_wait {pendings = 0 : i32} loc(#loc)
    tt.return loc(#loc)
  } loc(#loc)
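The updates here are the observable frontend effect of the change: the shared-memory layout is no longer threaded through as a separate smem_layout constexpr; the kernels read it from input_desc.layout, the warmup calls drop one argument, and every !tt.tensordesc in the expected IR gains the #shared encoding.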
Lines changed: 2 additions & 1 deletion

@@ -1,3 +1,4 @@
+from . import nvidia
 from ._runtime import jit
 
-__all__ = ["jit"]
+__all__ = ["jit", "nvidia"]

python/triton/experimental/gluon/language/_core.py

Lines changed: 4 additions & 1 deletion
@@ -16,7 +16,6 @@
     dtype,
     block_type,  # TODO: block type with layout info
     pointer_type,
-    tuple_type,
     void,
     int1,
     int8,
@@ -39,6 +38,8 @@
     _unwrap_if_constexpr,
     _unwrap_shape,
     tensor,
+    tuple,
+    tuple_type,
 )
 
 _IMPORT_FROM_TRITON: List[str] = [
@@ -88,6 +89,8 @@
     "float64",
     "_unwrap_if_constexpr",
     "tensor",
+    "tuple",
+    "tuple_type",
     "arange",
     "full",
     "convert_layout",

python/triton/experimental/gluon/language/_standard.py

Lines changed: 2 additions & 1 deletion
@@ -2,6 +2,7 @@
 import triton
 import triton.language.standard as tl_standard
 from .._runtime import jit
+from triton import knobs
 
 _IMPORT_FROM_TRITON = [
     "sum",
@@ -16,5 +17,5 @@
 for name in _IMPORT_FROM_TRITON:
     # Convert JITFunction -> GluonJitFunction
     fn = getattr(tl_standard, name)
-    assert isinstance(fn, triton.runtime.JITFunction)
+    assert knobs.runtime.interpret or isinstance(fn, triton.runtime.JITFunction)
     globals()[name] = jit(fn.fn)
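This tweak looks incidental to the feature: under the interpreter knob (knobs.runtime.interpret) the triton.language.standard helpers are presumably wrapped in a different callable type than JITFunction, so the assert is loosened rather than failing at import time in that mode.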

python/triton/experimental/gluon/language/nvidia/blackwell/tma.py

Lines changed: 4 additions & 0 deletions
@@ -3,6 +3,8 @@
     async_copy_global_to_shared,
     async_copy_shared_to_global,
     store_wait,
+    tensor_descriptor,
+    tensor_descriptor_type,
 )
 
 
@@ -11,6 +13,8 @@
     "async_copy_global_to_shared",
     "async_copy_shared_to_global",
     "store_wait",
+    "tensor_descriptor",
+    "tensor_descriptor_type",
 ]
 
 
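The Blackwell tma module simply re-exports the Hopper tensor_descriptor and tensor_descriptor_type, so the same descriptor arguments work on both targets.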
python/triton/experimental/gluon/language/nvidia/hopper/tma.py

Lines changed: 73 additions & 0 deletions
@@ -1,8 +1,81 @@
+from __future__ import annotations
+from typing import List, Tuple, TYPE_CHECKING
+from dataclasses import dataclass
+import triton.experimental.gluon.language._core as ttgl
+from triton.experimental.gluon.language._layouts import NVMMASharedLayout
 from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
 
+if TYPE_CHECKING:
+    from triton._C import ir
+
 __all__ = ["async_copy_global_to_shared", "async_copy_shared_to_global", "store_wait"]
 
 
+@dataclass(eq=True)
+class tensor_descriptor_type:
+    block_type: ttgl.block_type
+    shape_type: ttgl.tuple_type
+    strides_type: ttgl.tuple_type
+    layout: NVMMASharedLayout
+
+    def __str__(self) -> str:
+        return f"tensor_descriptor<{self.block_type}, {self.layout}>"
+
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor, int]:
+        handle = handles[cursor]
+        cursor += 1
+        shape, cursor = self.shape_type._unflatten_ir(handles, cursor)
+        strides, cursor = self.strides_type._unflatten_ir(handles, cursor)
+        value = tensor_descriptor(handle, shape, strides, self.block_type, layout=self.layout)
+        return value, cursor
+
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+        is_signed = self.block_type.element_ty.is_int_signed()
+        ty = builder.get_tensor_descriptor_layout_type(
+            self.block_type.to_ir(builder),
+            is_signed,
+            self.layout._to_ir(builder),
+        )
+        out.append(ty)
+        self.shape_type._flatten_ir_types(builder, out)
+        self.strides_type._flatten_ir_types(builder, out)
+
+    def mangle(self) -> str:
+        return f"TD{self.block_type.mangle}_{self.layout.mangle()}TD"
+
+
+class tensor_descriptor:
+
+    def __init__(self, handle, shape: List[ttgl.tensor], strides: List[ttgl.tensor], block_type: ttgl.block_type,
+                 layout: NVMMASharedLayout):
+        self.handle = handle
+        self.shape = ttgl.tuple(shape)
+        self.strides = ttgl.tuple(strides)
+        self.type = tensor_descriptor_type(block_type, shape_type=self.shape.type, strides_type=self.strides.type,
+                                           layout=layout)
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        handles.append(self.handle)
+        self.shape._flatten_ir(handles)
+        self.strides._flatten_ir(handles)
+
+    @property
+    def block_type(self):
+        return self.type.block_type
+
+    @property
+    def block_shape(self):
+        return self.type.block_type.shape
+
+    @property
+    def dtype(self):
+        return self.type.block_type.element_ty
+
+    @property
+    def layout(self):
+        return self.type.layout
+
+
 @builtin
 def async_copy_global_to_shared(tensor_desc, coord, barrier, result, pred=True, _semantic=None):
     coord = _semantic._convert_to_ir_values(coord, require_i64=False)
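This file carries the heart of the Python side. tensor_descriptor participates in Gluon's argument flattening protocol: _flatten_ir pushes one TMA handle followed by the shape and stride scalars, and _unflatten_ir reassembles them with a moving cursor. That is why the 2-D descriptors in the expected IR above expand to five kernel arguments (the !tt.tensordesc handle, two i32 shape values, and two i64 strides), and _flatten_ir_types is the call site for the new get_tensor_descriptor_layout_type binding from gluon_ir.cc.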
Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+from . import hopper
+from . import blackwell
+
+__all__ = ["hopper", "blackwell"]
Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+from .hopper import TensorDescriptor
+
+__all__ = ["TensorDescriptor"]
