Skip to content

Commit d510a3d

Browse files
authored
[Gluon] Add local_dealloc and enable local_alloc with no initializer (#6994)
`local_dealloc` is an unfortunate necessity because the compiler doesn't correctly reason about the live ranges of shared memory used by async operations. For now, users will need to manually keep shared memory alive using `smem._keep_alive()`.
1 parent 9f88c7f commit d510a3d

File tree

4 files changed

+28
-4
lines changed

4 files changed

+28
-4
lines changed

python/src/gluon_ir.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ void init_gluon_ir(py::module &&m) {
9494
[](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
9595
return self.create<ttg::ConvertLayoutOp>(resultTy, value);
9696
})
97+
.def("create_local_alloc",
98+
[](GluonOpBuilder &self, Type resultTy) -> Value {
99+
return self.create<ttg::LocalAllocOp>(resultTy);
100+
})
97101
.def("create_local_alloc",
98102
[](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
99103
return self.create<ttg::LocalAllocOp>(resultTy, value);
@@ -106,6 +110,11 @@ void init_gluon_ir(py::module &&m) {
106110
[](GluonOpBuilder &self, Type resultTy, Value memDesc) -> Value {
107111
return self.create<ttg::LocalLoadOp>(resultTy, memDesc);
108112
})
113+
.def("create_local_dealloc",
114+
[](GluonOpBuilder &self, Value memDesc) -> Operation * {
115+
return self.create<ttg::LocalDeallocOp>(memDesc);
116+
})
117+
109118
.def("create_tmem_alloc",
110119
[](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
111120
return self.create<ttng::TMEMAllocOp>(resultTy, value);

python/test/gluon/test_frontend.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,12 @@ def test_convert_layout(fresh_knobs):
4141
@gluon.jit
4242
def shared_memory_kernel(XBLOCK: ttgl.constexpr, YBLOCK: ttgl.constexpr, layout_a: ttgl.constexpr,
4343
layout_b: ttgl.constexpr, smem_layout: ttgl.constexpr):
44+
unused = ttgl.allocate_shared_memory(ttgl.int32, [XBLOCK, YBLOCK], smem_layout)
4445
a = ttgl.full([XBLOCK, YBLOCK], 0, ttgl.int32, layout_a)
4546
mem = ttgl.allocate_shared_memory(ttgl.int32, a.shape, smem_layout, a)
4647
b = mem.load(layout_b) # noqa: F841
4748
mem.store(a)
49+
unused._keep_alive()
4850

4951

5052
def test_shared_memory(fresh_knobs):
@@ -63,11 +65,13 @@ def test_shared_memory(fresh_knobs):
6365
#smem = #ttg.shared_memory
6466
module attributes {"ttg.num-warps" = 4 : i32} {
6567
tt.func public @shared_memory_kernel() attributes {noinline = false} {
68+
%0 = ttg.local_alloc : () -> !ttg.memdesc<8x32xi32, #shared, #smem, mutable> loc(#loc)
6669
%c0_i32 = arith.constant 0 : i32 loc(#loc)
6770
%cst = arith.constant dense<0> : tensor<8x32xi32, #blocked> loc(#loc)
68-
%0 = ttg.local_alloc %cst : (tensor<8x32xi32, #blocked>) -> !ttg.memdesc<8x32xi32, #shared, #smem, mutable> loc(#loc)
69-
%1 = ttg.local_load %0 : !ttg.memdesc<8x32xi32, #shared, #smem, mutable> -> tensor<8x32xi32, #blocked1> loc(#loc)
70-
ttg.local_store %cst, %0 : tensor<8x32xi32, #blocked> -> !ttg.memdesc<8x32xi32, #shared, #smem, mutable> loc(#loc)
71+
%1 = ttg.local_alloc %cst : (tensor<8x32xi32, #blocked>) -> !ttg.memdesc<8x32xi32, #shared, #smem, mutable> loc(#loc)
72+
%2 = ttg.local_load %1 : !ttg.memdesc<8x32xi32, #shared, #smem, mutable> -> tensor<8x32xi32, #blocked1> loc(#loc)
73+
ttg.local_store %cst, %1 : tensor<8x32xi32, #blocked> -> !ttg.memdesc<8x32xi32, #shared, #smem, mutable> loc(#loc)
74+
ttg.local_dealloc %0 : !ttg.memdesc<8x32xi32, #shared, #smem, mutable> loc(#loc)
7175
tt.return loc(#loc)
7276
} loc(#loc)
7377
} loc(#loc)

python/triton/experimental/gluon/language/_core.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,10 @@ def load(self, layout, _builder: GluonOpBuilder) -> tensor:
200200
def store(self, value, _builder: GluonOpBuilder) -> None:
201201
return semantic.shared_store(self, value, _builder)
202202

203+
@builtin
204+
def _keep_alive(self, _builder=None) -> None:
205+
return semantic.shared_dealloc(self, _builder)
206+
203207

204208
for name in _IMPORT_FROM_TRITON:
205209
fn = getattr(tl_core, name)

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ def convert_layout(value, layout, builder: GluonOpBuilder):
3232

3333
def allocate_shared(element_ty, shape, layout, value, builder: GluonOpBuilder):
3434
ty = ttgl.shared_memory_descriptor_type(element_ty, shape, layout, shape)
35-
handle = builder.create_local_alloc(ty.to_ir(builder), value.handle)
35+
if value is not None:
36+
handle = builder.create_local_alloc(ty.to_ir(builder), value.handle)
37+
else:
38+
handle = builder.create_local_alloc(ty.to_ir(builder))
3639
return ttgl.shared_memory_descriptor(handle, element_ty, shape, layout, shape)
3740

3841

@@ -46,6 +49,10 @@ def shared_store(mem_desc, value, builder: GluonOpBuilder):
4649
builder.create_local_store(mem_desc.handle, value.handle)
4750

4851

52+
def shared_dealloc(mem_desc, builder: GluonOpBuilder):
53+
builder.create_local_dealloc(mem_desc.handle)
54+
55+
4956
def warp_specialize(args, default_partition, worker_partitions, worker_num_warps: Sequence[int],
5057
worker_num_regs: Sequence[int], builder: GluonOpBuilder, generator):
5158
num_partitions = len(worker_partitions)

0 commit comments

Comments
 (0)