Skip to content

Commit 9f88c7f

Browse files
authored
[Gluon] Implement tensor memory (#6985)
This implements: - `ttgl.nvidia.blackwell.allocate_tensor` - `ttgl.nvidia.blackwell.TensorMemoryLayout` - `tensor_memory_descriptor.load` - `tensor_memory_descriptor.store` - `tensor_memory_descriptor.subslice`
1 parent 2b797c9 commit 9f88c7f

File tree

8 files changed

+251
-24
lines changed

8 files changed

+251
-24
lines changed

lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,10 @@ class TritonTensorMemoryAllocationPass
314314
: public impl::TritonTensorMemoryAllocationPassBase<
315315
TritonTensorMemoryAllocationPass> {
316316
public:
317+
IntegerAttr getI32Attr(int32_t value) {
318+
return Builder(&getContext()).getI32IntegerAttr(value);
319+
}
320+
317321
void runOnOperation() override {
318322
ModuleOp mod = getOperation();
319323
MLIRContext *ctx = &getContext();
@@ -323,6 +327,9 @@ class TritonTensorMemoryAllocationPass
323327
int totalMemorySize = allocateTMem(mod, offsets);
324328

325329
std::array<int, 6> possibleAllocations = {0, 32, 64, 128, 256, 512};
330+
// NOTE: if totalMemorySize > 512 we exceeded the maximum amount of tensor
331+
// memory, but we let the compilation finish so that we can raise an
332+
// exception in python for the auto-tuner.
326333
if (totalMemorySize <= 512) {
327334
for (int size : possibleAllocations) {
328335
if (totalMemorySize <= size) {
@@ -331,18 +338,18 @@ class TritonTensorMemoryAllocationPass
331338
}
332339
}
333340
}
334-
// if totalMemorySize > 512 we exceeded the maximum amount of tensor memory,
335-
// let the compilation finish so that we can raise an exception in python
336-
// for auto-tuner.
337341
if (totalMemorySize > 0) {
338-
assert(mod->getAttr("ttg.shared") != nullptr &&
339-
cast<IntegerAttr>(mod->getAttr("ttg.shared")).getInt() != 0 &&
340-
"Shared memory is required for allocation of Tensor Core memory.");
342+
// We use a small smem allocation to get the tensor memory base address
343+
// from tcgen05.alloc, ensure the block has at least 4 bytes of smem
344+
int shared = 0;
345+
if (auto sharedAttr = mod->getAttr("ttg.shared")) {
346+
shared = cast<IntegerAttr>(sharedAttr).getInt();
347+
}
348+
if (shared < 4) {
349+
mod->setAttr("ttg.shared", getI32Attr(4));
350+
}
341351
}
342-
343-
mod->setAttr("ttg.tensor_memory_size",
344-
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
345-
totalMemorySize));
352+
mod->setAttr("ttg.tensor_memory_size", getI32Attr(totalMemorySize));
346353
}
347354
};
348355

python/src/gluon_ir.cc

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44

55
#include "mlir/IR/BuiltinTypes.h"
66
#include "mlir/IR/Types.h"
7-
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
87
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
98
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
9+
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
1010

1111
using namespace mlir;
1212
namespace py = pybind11;
1313
namespace ttg = triton::gpu;
14+
namespace ttng = triton::nvidia_gpu;
1415

1516
struct GluonOpBuilder : public TritonOpBuilder {};
1617

@@ -35,6 +36,16 @@ void init_gluon_ir(py::module &&m) {
3536
/*mutableMemory=*/true,
3637
/*allocShape=*/allocShape);
3738
})
39+
.def("get_tensor_mem_desc_ty",
40+
[](GluonOpBuilder &self, Type &elementType,
41+
std::vector<int64_t> &shape, Attribute layout,
42+
std::vector<int64_t> &allocShape) -> Type {
43+
auto ctx = self.getContext();
44+
return ttg::MemDescType::get(shape, elementType, layout,
45+
ttng::TensorMemorySpaceAttr::get(ctx),
46+
/*mutableMemory=*/true,
47+
/*allocShape=*/allocShape);
48+
})
3849
.def("get_blocked_layout",
3950
[](GluonOpBuilder &self, std::vector<unsigned> &sizePerThread,
4051
std::vector<unsigned> &threadsPerWarp,
@@ -69,6 +80,16 @@ void init_gluon_ir(py::module &&m) {
6980
ctx, swizzleByteWidth, transposed, elementBitwidth, fp4Padded,
7081
ctaLayout);
7182
})
83+
.def("get_tensor_memory_layout",
84+
[](GluonOpBuilder &self, std::vector<unsigned> &block, bool unpacked,
85+
std::vector<unsigned> &ctaSplitNum) -> Attribute {
86+
auto ctx = self.getContext();
87+
assert(block.size() == 2);
88+
assert(ctaSplitNum.size() == 2);
89+
return ttng::TensorMemoryEncodingAttr::get(
90+
ctx, block[0], block[1], unpacked, ctaSplitNum[0],
91+
ctaSplitNum[1]);
92+
})
7293
.def("create_convert_layout",
7394
[](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
7495
return self.create<ttg::ConvertLayoutOp>(resultTy, value);
@@ -85,7 +106,23 @@ void init_gluon_ir(py::module &&m) {
85106
[](GluonOpBuilder &self, Type resultTy, Value memDesc) -> Value {
86107
return self.create<ttg::LocalLoadOp>(resultTy, memDesc);
87108
})
88-
109+
.def("create_tmem_alloc",
110+
[](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
111+
return self.create<ttng::TMEMAllocOp>(resultTy, value);
112+
})
113+
.def("create_tmem_store",
114+
[](GluonOpBuilder &self, Value memDesc, Value value, Value pred) {
115+
self.create<ttng::TMEMStoreOp>(memDesc, value, pred);
116+
})
117+
.def("create_tmem_load",
118+
[](GluonOpBuilder &self, Type resultTy, Value memDesc) -> Value {
119+
return self.create<ttng::TMEMLoadOp>(resultTy, memDesc);
120+
})
121+
.def("create_tmem_subslice",
122+
[](GluonOpBuilder &self, Type resultTy, Value memDesc,
123+
int N) -> Value {
124+
return self.create<ttng::TMEMSubSliceOp>(resultTy, memDesc, N);
125+
})
89126
.def("create_warp_return",
90127
[](GluonOpBuilder &self) -> Operation * {
91128
return self.create<ttg::WarpReturnOp>();

python/test/gluon/test_frontend.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import expecttest
2+
import torch
3+
import pytest
24

35
from triton import knobs
46
from triton.experimental import gluon
57
from triton.experimental.gluon import language as ttgl
68
from triton._filecheck import filecheck_test
79
import triton.language as tl
10+
from triton._internal_testing import is_cuda
811

912

1013
@gluon.jit
@@ -39,7 +42,7 @@ def test_convert_layout(fresh_knobs):
3942
def shared_memory_kernel(XBLOCK: ttgl.constexpr, YBLOCK: ttgl.constexpr, layout_a: ttgl.constexpr,
4043
layout_b: ttgl.constexpr, smem_layout: ttgl.constexpr):
4144
a = ttgl.full([XBLOCK, YBLOCK], 0, ttgl.int32, layout_a)
42-
mem = ttgl.allocate_shared(ttgl.int32, a.shape, smem_layout, a)
45+
mem = ttgl.allocate_shared_memory(ttgl.int32, a.shape, smem_layout, a)
4346
b = mem.load(layout_b) # noqa: F841
4447
mem.store(a)
4548

@@ -72,6 +75,47 @@ def test_shared_memory(fresh_knobs):
7275
""")
7376

7477

78+
@gluon.jit
79+
def tensor_memory_kernel(layout: ttgl.constexpr, tmem_layout: ttgl.constexpr):
80+
XBLOCK: ttgl.constexpr = tmem_layout.block[0]
81+
YBLOCK: ttgl.constexpr = tmem_layout.block[1]
82+
a = ttgl.full([XBLOCK, YBLOCK], 0, ttgl.int32, layout)
83+
mem = ttgl.nvidia.blackwell.allocate_tensor_memory(ttgl.int32, a.shape, tmem_layout, a)
84+
b = mem.load(layout) # noqa: F841
85+
mem.store(a)
86+
slice1 = mem.subslice(0, YBLOCK // 2) # noqa: F841
87+
slice2 = mem.subslice(YBLOCK // 2, YBLOCK // 2) # noqa: F841
88+
89+
90+
@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 10,
91+
reason="Requires blackwell tensor cores")
92+
def test_tensor_memory(fresh_knobs):
93+
knobs.compilation.disable_line_info = True
94+
95+
layout = ttgl.BlockedLayout(size_per_thread=[1, 64], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1])
96+
tmem_layout = ttgl.nvidia.blackwell.TensorMemoryLayout(block=[128, 128], unpacked=True)
97+
h = tensor_memory_kernel.warmup(layout, tmem_layout, num_warps=4, grid=(1, ))
98+
expecttest.assert_expected_inline(
99+
h.asm["ttgir"], """\
100+
#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
101+
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
102+
module attributes {"ttg.num-warps" = 4 : i32} {
103+
tt.func public @tensor_memory_kernel() attributes {noinline = false} {
104+
%c0_i32 = arith.constant 0 : i32 loc(#loc)
105+
%cst = arith.constant dense<0> : tensor<128x128xi32, #blocked> loc(#loc)
106+
%result = ttng.tmem_alloc %cst : (tensor<128x128xi32, #blocked>) -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
107+
%result_0 = ttng.tmem_load %result : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xi32, #blocked> loc(#loc)
108+
%true = arith.constant true loc(#loc)
109+
ttng.tmem_store %cst, %result, %true : tensor<128x128xi32, #blocked> -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
110+
%0 = ttng.tmem_subslice %result {N = 0 : i32} : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xi32, #tmem, #ttng.tensor_memory, mutable, 128x128> loc(#loc)
111+
%1 = ttng.tmem_subslice %result {N = 64 : i32} : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xi32, #tmem, #ttng.tensor_memory, mutable, 128x128> loc(#loc)
112+
tt.return loc(#loc)
113+
} loc(#loc)
114+
} loc(#loc)
115+
#loc = loc(unknown)
116+
""")
117+
118+
75119
@gluon.jit
76120
def warp_specialize_default(a, b):
77121
return b, a

python/triton/compiler/code_generator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,6 +1349,9 @@ def visit_Attribute(self, node):
13491349
lhs = self.visit(node.value)
13501350
if _is_triton_tensor(lhs) and node.attr == "T":
13511351
return semantic.permute(lhs, (1, 0), builder=self.builder)
1352+
# NOTE: special case ".value" for BC
1353+
if isinstance(lhs, constexpr) and node.attr != "value":
1354+
lhs = lhs.value
13521355
attr = getattr(lhs, node.attr)
13531356
if _is_triton_value(lhs) and isinstance(attr, JITFunction):
13541357
return BoundJITMethod(lhs, attr)

python/triton/experimental/gluon/language/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,10 @@
33
from ._layouts import * # NOQA: F403
44
from ._layouts import __all__ as __layouts_all
55

6-
__all__ = [*__core_all, *__layouts_all]
6+
from . import nvidia
7+
8+
__all__ = [
9+
*__core_all,
10+
*__layouts_all,
11+
"nvidia",
12+
]

python/triton/experimental/gluon/language/_core.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@
4141
)
4242
from . import _semantic as semantic
4343

44+
_IMPORT_FROM_TRITON: List[str] = [
45+
"program_id", # NOQA: F822
46+
"load", # NOQA: F822
47+
"store", # NOQA: F822
48+
"to_tensor", # NOQA: F822
49+
]
50+
4451
__all__ = [
4552
"constexpr",
4653
"base_value",
@@ -71,15 +78,13 @@
7178
"float64",
7279
"_unwrap_if_constexpr",
7380
"tensor",
74-
"program_id", # NOQA: F822
75-
"load", # NOQA: F822
76-
"store", # NOQA: F822
7781
"arange",
7882
"full",
7983
"convert_layout",
80-
"allocate_shared",
84+
"allocate_shared_memory",
8185
"shared_memory_descriptor",
8286
"warp_specialize",
87+
*_IMPORT_FROM_TRITON,
8388
]
8489

8590
T = TypeVar("T")
@@ -196,11 +201,7 @@ def store(self, value, _builder: GluonOpBuilder) -> None:
196201
return semantic.shared_store(self, value, _builder)
197202

198203

199-
for name in [
200-
"program_id",
201-
"load",
202-
"store",
203-
]:
204+
for name in _IMPORT_FROM_TRITON:
204205
fn = getattr(tl_core, name)
205206
globals()[name] = builtin(fn)
206207

@@ -229,7 +230,7 @@ def full(shape, value, dtype, layout, _builder=None):
229230

230231

231232
@builtin
232-
def allocate_shared(element_ty, shape, layout, value=None, _builder=None):
233+
def allocate_shared_memory(element_ty, shape, layout, value=None, _builder=None):
233234
element_ty = _unwrap_if_constexpr(element_ty)
234235
shape = _unwrap_if_constexpr(shape)
235236
layout = _unwrap_if_constexpr(layout)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from . import blackwell
2+
3+
__all__ = ["blackwell"]

0 commit comments

Comments
 (0)