Skip to content

Commit 268ead7

Browse files
authored
[Gluon] Fix memdesc_trans alloc shape (#7149)
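`memdesc_trans` was not returning `new_alloc_shape`: the resulting `shared_memory_descriptor` was constructed with `alloc_shape` instead of `new_alloc_shape`.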
1 parent b1f5c66 commit 268ead7

File tree

3 files changed

+14
-3
lines changed

3 files changed

+14
-3
lines changed

python/test/gluon/test_frontend.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,8 @@ def shared_memory_cast_kernel():
252252
smem = ttgl.allocate_shared_memory(ttgl.int8, [2, 256, 128], layout_a)
253253
perm = smem.index(0).permute((1, 0))
254254
ttgl.static_assert(perm.type.layout == layout_T)
255+
# Check that the MLIR type and Gluon types match by emitting a call.
256+
anchor_noinline(perm)
255257

256258
layout_b: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=False, element_bitwidth=16,
257259
rank=4, cta_order=[3, 2, 1, 0])
@@ -279,11 +281,15 @@ def test_shared_memory_cast(fresh_knobs):
279281
%c0_i32_0 = arith.constant 0 : i32
280282
%1 = ttg.memdesc_subview %0[%c0_i32_0, %c0_i32, %c0_i32] : !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128>
281283
%2 = ttg.memdesc_trans %1 {order = array<i32: 1, 0>} : !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128> -> !ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>
284+
tt.call @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[2, 128, 256]ASMD__"(%2) : (!ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>) -> ()
282285
%3 = ttg.local_alloc : () -> !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable>
283286
%4 = ttg.memdesc_reshape %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared3, #smem, mutable, 32x1x4x64>
284287
%5 = ttg.memdesc_reinterpret %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<1024xi8, #shared4, #smem, mutable>
285288
tt.return
286289
}
290+
tt.func private @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[2, 128, 256]ASMD__"(%arg0: !ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>) attributes {noinline = true} {
291+
tt.return
292+
}
287293
}
288294
""")
289295

@@ -318,6 +324,11 @@ def anchor(x):
318324
pass
319325

320326

327+
@gluon.jit(noinline=True)
328+
def anchor_noinline(x):
329+
pass
330+
331+
321332
@filecheck_test
322333
@gluon.jit
323334
def test_warp_specialize():

python/triton/experimental/gluon/language/_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None
173173
out.append(self.to_ir(builder))
174174

175175
def __str__(self) -> str:
176-
return f"shared_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}>"
176+
return f"shared_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}, {self.alloc_shape}>"
177177

178178
def __eq__(self, other) -> bool:
179179
return (type(self) is type(other) and self.shape == other.shape and self.layout == other.layout

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,8 @@ def memdesc_trans(self, mem_desc, order):
191191

192192
handle = self.builder.create_memdesc_trans(mem_desc.handle, order)
193193
layout = self.builder.get_gluon_layout_from_memdesc(handle)
194-
return ttgl.shared_memory_descriptor(handle, element_ty=mem_desc.dtype, shape=shape, alloc_shape=alloc_shape,
195-
layout=layout)
194+
return ttgl.shared_memory_descriptor(handle, element_ty=mem_desc.dtype, shape=shape,
195+
alloc_shape=new_alloc_shape, layout=layout)
196196

197197
def memdesc_reshape(self, mem_desc, shape, layout):
198198
ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)

0 commit comments

Comments
 (0)