Skip to content

Commit e594544

Browse files
authored
[Gluon] Infer memdesc_trans layout automatically (#7121)
1 parent 855fe3a commit e594544

File tree

4 files changed

+19
-12
lines changed

4 files changed

+19
-12
lines changed

python/src/gluon_ir.cc

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -264,6 +264,12 @@ void init_gluon_ir(py::module &&m) {
264264
assert(ty.getEncoding());
265265
return layoutToGluon(ty.getEncoding());
266266
})
267+
.def("get_gluon_layout_from_memdesc",
268+
[](GluonOpBuilder &self, Value memdesc) -> py::object {
269+
auto ty = dyn_cast<ttg::MemDescType>(memdesc.getType());
270+
assert(ty.getEncoding());
271+
return layoutToGluon(ty.getEncoding());
272+
})
267273
.def("create_convert_layout",
268274
[](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
269275
return self.create<ttg::ConvertLayoutOp>(resultTy, value);
@@ -296,9 +302,9 @@ void init_gluon_ir(py::module &&m) {
296302
offsets);
297303
})
298304
.def("create_memdesc_trans",
299-
[](GluonOpBuilder &self, Type resultType, Value src,
305+
[](GluonOpBuilder &self, Value src,
300306
std::vector<int> &order) -> Value {
301-
return self.create<ttg::MemDescTransOp>(resultType, src, order);
307+
return self.create<ttg::MemDescTransOp>(src, order);
302308
})
303309
.def("create_memdesc_reshape",
304310
[](GluonOpBuilder &self, Type resultType, Value src) -> Value {

python/test/gluon/test_frontend.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -247,9 +247,11 @@ def shared_memory_cast_kernel():
247247
layout_a: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=False, element_bitwidth=8,
248248
rank=2)
249249
layout_T: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=True, element_bitwidth=8,
250-
rank=2)
250+
rank=2, ctas_per_cga=[1, 1], cta_split_num=[1,
251+
1], cta_order=[1, 0])
251252
smem = ttgl.allocate_shared_memory(ttgl.int8, [2, 256, 128], layout_a)
252-
smem.subslice(0).permute((1, 0), layout_T)
253+
perm = smem.subslice(0).permute((1, 0))
254+
ttgl.static_assert(perm.type.layout == layout_T)
253255

254256
layout_b: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=False, element_bitwidth=16,
255257
rank=4, cta_order=[3, 2, 1, 0])

python/triton/experimental/gluon/language/_core.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -239,11 +239,9 @@ def subslice(self, index, shape=None, layout=None, _semantic: GluonSemantic = No
239239
return _semantic.memdesc_slice(self, index, shape, layout)
240240

241241
@builtin
242-
def permute(self, order, layout, _semantic: GluonSemantic) -> shared_memory_descriptor:
242+
def permute(self, order, _semantic: GluonSemantic) -> shared_memory_descriptor:
243243
order = [_unwrap_if_constexpr(o) for o in order]
244-
layout = _unwrap_if_constexpr(layout)
245-
246-
return _semantic.memdesc_trans(self, order, layout)
244+
return _semantic.memdesc_trans(self, order)
247245

248246
@builtin
249247
def reshape(self, shape, layout, _semantic: GluonSemantic) -> shared_memory_descriptor:

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -163,7 +163,7 @@ def memdesc_slice(self, mem_desc, index, shape, layout):
163163
offsets[0] = self._convert_elem_to_ir_value(index, require_i64=False)
164164
return self._memdesc_subview(mem_desc, offsets, shape, layout)
165165

166-
def memdesc_trans(self, mem_desc, order, layout):
166+
def memdesc_trans(self, mem_desc, order):
167167
assert len(order) == len(
168168
mem_desc.shape), f"source rank ({mem_desc.rank}) and order length ({len(order)}) must match"
169169

@@ -172,9 +172,10 @@ def memdesc_trans(self, mem_desc, order, layout):
172172
new_alloc_shape = alloc_shape[:len(alloc_shape) - mem_desc.rank]
173173
new_alloc_shape += [alloc_shape[len(alloc_shape) - mem_desc.rank:][i] for i in order]
174174

175-
ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, new_alloc_shape)
176-
handle = self.builder.create_memdesc_trans(ty.to_ir(self.builder), mem_desc.handle, order)
177-
return ttgl.shared_memory_descriptor(handle, **ty.__dict__)
175+
handle = self.builder.create_memdesc_trans(mem_desc.handle, order)
176+
layout = self.builder.get_gluon_layout_from_memdesc(handle)
177+
return ttgl.shared_memory_descriptor(handle, element_ty=mem_desc.dtype, shape=shape, alloc_shape=alloc_shape,
178+
layout=layout)
178179

179180
def memdesc_reshape(self, mem_desc, shape, layout):
180181
ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)

0 commit comments

Comments
 (0)