Commit fed9ac4

[Gluon] Add tcgen05_mma and fix non-initialized allocate_tensor_memory (#6998)
1 parent 307680f commit fed9ac4

File tree: 3 files changed, +83 −6 lines

python/src/gluon_ir.cc

Lines changed: 16 additions & 0 deletions
```diff
@@ -6,6 +6,7 @@
 #include "mlir/IR/Types.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Types.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 
 using namespace mlir;
@@ -130,6 +131,10 @@ void init_gluon_ir(py::module &&m) {
            [](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
              return self.create<ttng::TMEMAllocOp>(resultTy, value);
            })
+      .def("create_tmem_alloc",
+           [](GluonOpBuilder &self, Type resultTy, py::none value) -> Value {
+             return self.create<ttng::TMEMAllocOp>(resultTy, Value{});
+           })
       .def("create_tmem_store",
            [](GluonOpBuilder &self, Value memDesc, Value value, Value pred) {
              self.create<ttng::TMEMStoreOp>(memDesc, value, pred);
@@ -164,6 +169,17 @@ void init_gluon_ir(py::module &&m) {
            [](GluonOpBuilder &self, Value memDesc, int count, Value pred) {
              self.create<ttng::ArriveBarrierOp>(memDesc, count, pred);
            })
+      .def("create_tcgen05_mma",
+           [](GluonOpBuilder &self, Value a, Value b, Value acc, Value useAcc,
+              Value pred, std::vector<Value> &mbarriers,
+              std::vector<Value> &mbarrier_preds) {
+             Value accDep;
+             bool two_ctas = false;
+             auto tokType = self.getBuilder().getType<ttg::AsyncTokenType>();
+             self.create<ttng::TCGen5MMAOp>(tokType, a, b, acc, accDep, useAcc,
+                                            pred, two_ctas, mbarriers,
+                                            mbarrier_preds);
+           })
       .def("create_warp_return",
            [](GluonOpBuilder &self) -> Operation * {
              return self.create<ttg::WarpReturnOp>();
```
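The second `create_tmem_alloc` overload exists because pybind11 cannot convert Python `None` into an MLIR `Value`; registering a `py::none` overload lets the frontend pass `None` to request an uninitialized allocation. A minimal sketch of how the Python side selects between the two overloads; `builder` and `memdesc_ty` are hypothetical stand-ins for objects obtained elsewhere:

```python
# Hedged sketch: `builder` is a GluonOpBuilder binding and `memdesc_ty` a
# tensor-memory descriptor IR type; both names are illustrative only.
init_handle = builder.create_tmem_alloc(memdesc_ty, cst.handle)  # Value overload
uninit_handle = builder.create_tmem_alloc(memdesc_ty, None)      # py::none overload
```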

python/test/gluon/test_frontend.py

Lines changed: 45 additions & 5 deletions
```diff
@@ -5,6 +5,7 @@
 from triton import knobs
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
+from triton.experimental.gluon.language.nvidia import blackwell
 from triton.experimental.gluon.language.nvidia.blackwell import mbarrier
 from triton._filecheck import filecheck_test
 import triton.language as tl
@@ -85,6 +86,7 @@ def tensor_memory_kernel(layout: ttgl.constexpr, tmem_layout: ttgl.constexpr):
     XBLOCK: ttgl.constexpr = tmem_layout.block[0]
     YBLOCK: ttgl.constexpr = tmem_layout.block[1]
     a = ttgl.full([XBLOCK, YBLOCK], 0, ttgl.int32, layout)
+    _ = ttgl.nvidia.blackwell.allocate_tensor_memory(ttgl.int32, a.shape, tmem_layout)
     mem = ttgl.nvidia.blackwell.allocate_tensor_memory(ttgl.int32, a.shape, tmem_layout, a)
     b = mem.load(layout)  # noqa: F841
     mem.store(a)
@@ -108,12 +110,13 @@ def test_tensor_memory(fresh_knobs):
   tt.func public @tensor_memory_kernel() attributes {noinline = false} {
     %c0_i32 = arith.constant 0 : i32 loc(#loc)
     %cst = arith.constant dense<0> : tensor<128x128xi32, #blocked> loc(#loc)
-    %result = ttng.tmem_alloc %cst : (tensor<128x128xi32, #blocked>) -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
-    %result_0 = ttng.tmem_load %result : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xi32, #blocked> loc(#loc)
+    %result = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
+    %result_0 = ttng.tmem_alloc %cst : (tensor<128x128xi32, #blocked>) -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
+    %result_1 = ttng.tmem_load %result_0 : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xi32, #blocked> loc(#loc)
     %true = arith.constant true loc(#loc)
-    ttng.tmem_store %cst, %result, %true : tensor<128x128xi32, #blocked> -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
-    %0 = ttng.tmem_subslice %result {N = 0 : i32} : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xi32, #tmem, #ttng.tensor_memory, mutable, 128x128> loc(#loc)
-    %1 = ttng.tmem_subslice %result {N = 64 : i32} : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xi32, #tmem, #ttng.tensor_memory, mutable, 128x128> loc(#loc)
+    ttng.tmem_store %cst, %result_0, %true : tensor<128x128xi32, #blocked> -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
+    %0 = ttng.tmem_subslice %result_0 {N = 0 : i32} : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xi32, #tmem, #ttng.tensor_memory, mutable, 128x128> loc(#loc)
+    %1 = ttng.tmem_subslice %result_0 {N = 64 : i32} : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xi32, #tmem, #ttng.tensor_memory, mutable, 128x128> loc(#loc)
     tt.return loc(#loc)
   } loc(#loc)
 } loc(#loc)
@@ -217,3 +220,40 @@ def test_mbarrier(fresh_knobs):
 } loc(#loc)
 #loc = loc(unknown)
 """)
+
+
+@gluon.jit
+def tcgen05_mma_kernel(nvmma_layout: ttgl.constexpr, acc_layout: ttgl.constexpr):
+    a = ttgl.allocate_shared_memory(ttgl.float16, [128, 128], nvmma_layout)
+    b = ttgl.allocate_shared_memory(ttgl.float16, [128, 128], nvmma_layout)
+    acc = blackwell.allocate_tensor_memory(ttgl.float16, [128, 128], acc_layout)
+    blackwell.tcgen05_mma(a, b, acc)
+
+
+@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 10,
+                    reason="Requires blackwell tensor core")
+def test_tcgen05_mma(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    nvmma_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
+    acc_layout = blackwell.TensorMemoryLayout([128, 128], unpacked=True)
+
+    h = tcgen05_mma_kernel.warmup(nvmma_layout, acc_layout, grid=(1, ))
+    expecttest.assert_expected_inline(
+        h.asm["ttgir"], """\
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#smem = #ttg.shared_memory
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+module attributes {"ttg.num-warps" = 4 : i32} {
+  tt.func public @tcgen05_mma_kernel() attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    %1 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    %result = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
+    %true = arith.constant true loc(#loc)
+    %true_0 = arith.constant true loc(#loc)
+    %2 = ttng.tc_gen5_mma %0, %1, %result[], %true, %true_0 : !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+#loc = loc(unknown)
+""")
```

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 22 additions & 1 deletion
```diff
@@ -127,7 +127,28 @@ def allocate_tensor_memory(element_ty, shape, layout, value=None, _builder=None)
     element_ty = _unwrap_if_constexpr(element_ty)
     shape = _unwrap_if_constexpr(shape)
     layout = _unwrap_if_constexpr(layout)
+    value = value.handle if value is not None else None
 
     ty = tensor_memory_descriptor_type(element_ty, shape, layout, shape)
-    handle = _builder.create_tmem_alloc(ty.to_ir(_builder), value.handle)
+    handle = _builder.create_tmem_alloc(ty.to_ir(_builder), value)
     return tensor_memory_descriptor(handle, element_ty, shape, layout, shape)
+
+
+@builtin
+def tcgen05_mma(a, b, acc, *, use_acc=True, pred=True, mbarriers=None, mbarrier_preds=None, _builder=None):
+    use_acc = ttgl.to_tensor(use_acc, _builder=_builder)
+    pred = ttgl.to_tensor(pred, _builder=_builder)
+
+    if mbarriers is None:
+        assert mbarrier_preds is None
+        mbarriers = []
+        mbarrier_preds = []
+    else:
+        mbarriers = [bar.handle for bar in mbarriers]
+        if mbarrier_preds is None:
+            true = ttgl.to_tensor(True, _builder=_builder)
+            mbarrier_preds = [true.handle] * len(mbarriers)
+        else:
+            mbarrier_preds = [pred.handle for pred in mbarrier_preds]
+
+    _builder.create_tcgen05_mma(a.handle, b.handle, acc.handle, use_acc.handle, pred.handle, mbarriers, mbarrier_preds)
```
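Beyond the barrier-free form exercised by the test, `tcgen05_mma` accepts completion mbarriers. A hedged sketch of that path, combining the new builtin with the existing `mbarrier` helpers imported in the test file; the kernel below is illustrative only and not part of this commit:

```python
# Hedged sketch (assumed API, not from this commit): have the MMA signal an
# mbarrier on completion instead of relying on the default completion path.
@gluon.jit
def tcgen05_mma_mbarrier_kernel(nvmma_layout: ttgl.constexpr, acc_layout: ttgl.constexpr):
    a = ttgl.allocate_shared_memory(ttgl.float16, [128, 128], nvmma_layout)
    b = ttgl.allocate_shared_memory(ttgl.float16, [128, 128], nvmma_layout)
    acc = blackwell.allocate_tensor_memory(ttgl.float16, [128, 128], acc_layout)
    bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(bar, count=1)
    # mbarrier_preds defaults to all-true when only mbarriers is given.
    blackwell.tcgen05_mma(a, b, acc, mbarriers=[bar])
    mbarrier.wait(bar, 0)  # block until the MMA arrives on the barrier
    mbarrier.invalidate(bar)
```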
