Skip to content

Commit b6dabff

Browse files
authored
[Gluon] Fix memdesc_trans alloc shape inference and constexpr getitem (#7102)
1 parent e461a5b commit b6dabff

File tree

4 files changed

+35
-8
lines changed

4 files changed

+35
-8
lines changed

python/test/gluon/test_frontend.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,8 @@ def shared_memory_cast_kernel():
247247
rank=2)
248248
layout_T: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=True, element_bitwidth=8,
249249
rank=2)
250-
smem = ttgl.allocate_shared_memory(ttgl.int8, [256, 128], layout_a)
251-
smem.permute((1, 0), layout_T)
250+
smem = ttgl.allocate_shared_memory(ttgl.int8, [2, 256, 128], layout_a)
251+
smem.subslice(0).permute((1, 0), layout_T)
252252

253253
layout_b: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=False, element_bitwidth=16,
254254
rank=4, cta_order=[3, 2, 1, 0])
@@ -271,11 +271,14 @@ def test_shared_memory_cast(fresh_knobs):
271271
#smem = #ttg.shared_memory
272272
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
273273
tt.func public @shared_memory_cast_kernel() attributes {noinline = false} {
274-
%0 = ttg.local_alloc : () -> !ttg.memdesc<256x128xi8, #shared, #smem, mutable>
275-
%1 = ttg.memdesc_trans %0 {order = array<i32: 1, 0>} : !ttg.memdesc<256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<128x256xi8, #shared1, #smem, mutable>
276-
%2 = ttg.local_alloc : () -> !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable>
277-
%3 = ttg.memdesc_reshape %2 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared3, #smem, mutable, 32x1x4x64>
278-
%4 = ttg.memdesc_reinterpret %2 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<1024xi8, #shared4, #smem, mutable>
274+
%0 = ttg.local_alloc : () -> !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable>
275+
%c0_i32 = arith.constant 0 : i32
276+
%c0_i32_0 = arith.constant 0 : i32
277+
%1 = ttg.memdesc_subview %0[%c0_i32_0, %c0_i32, %c0_i32] : !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128>
278+
%2 = ttg.memdesc_trans %1 {order = array<i32: 1, 0>} : !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128> -> !ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>
279+
%3 = ttg.local_alloc : () -> !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable>
280+
%4 = ttg.memdesc_reshape %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared3, #smem, mutable, 32x1x4x64>
281+
%5 = ttg.memdesc_reinterpret %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<1024xi8, #shared4, #smem, mutable>
279282
tt.return
280283
}
281284
}

python/test/unit/language/test_frontend.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,23 @@ def test_reassign_aggregate_with_constexpr():
257257
agg = agg.modify(tl.arange(8, 12))
258258
# CHECK: call @{{.*}}anchor{{.*}}([[AGG]])
259259
anchor(agg)
260+
261+
262+
@tl.constexpr_function
263+
def make_shape(m, n):
264+
return (m, n)
265+
266+
267+
@tl.constexpr_function
268+
def add_shape_dims(m, n):
269+
return m + n
270+
271+
272+
@filecheck_test
273+
@triton.jit
274+
def test_constexpr_getitem():
275+
# CHECK-LABEL: test_constexpr_getitem
276+
# CHECK: make_range {end = 12 : i32, start = 4 : i32}
277+
shape: tl.constexpr = make_shape(4, 8)
278+
sum: tl.constexpr = add_shape_dims(shape[0], shape[1])
279+
tl.arange(4, sum)

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def memdesc_trans(self, mem_desc, order, layout):
164164
shape = [mem_desc.shape[i] for i in order]
165165
alloc_shape = mem_desc.type.alloc_shape
166166
new_alloc_shape = alloc_shape[:len(alloc_shape) - mem_desc.rank]
167-
new_alloc_shape += [alloc_shape[:mem_desc.rank][i] for i in order]
167+
new_alloc_shape += [alloc_shape[len(alloc_shape) - mem_desc.rank:][i] for i in order]
168168

169169
ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, new_alloc_shape)
170170
handle = self.builder.create_memdesc_trans(ty.to_ir(self.builder), mem_desc.handle, order)

python/triton/language/core.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,10 @@ def __iter__(self):
329329
def __call__(self, *args, **kwds):
330330
return self.value(*args, **kwds)
331331

332+
def __getitem__(self, *args):
333+
args = (_unwrap_if_constexpr(x) for x in _normalize_tuple(args))
334+
return self.value.__getitem__(*args)
335+
332336

333337
def constexpr_function(f):
334338
"""

0 commit comments

Comments
 (0)