Skip to content

Commit 7a505f9

Browse files
authored
[Gluon] Fix layout mangling (#7020)
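This fixes several name-mangling issues (stringifying memory-descriptor shape elements, unwrapping constexpr shape elements in `allocate_shared_memory`, adding or correcting layout `mangle` methods, and giving tuples a `__repr__`) so that layout objects can be passed as constexprs to other functions.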
1 parent fb2e693 commit 7a505f9

File tree

5 files changed

+49
-3
lines changed

5 files changed

+49
-3
lines changed

python/test/gluon/test_frontend.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,3 +539,33 @@ def kernel():
539539
run_parser(kernel)
540540

541541
assert "order must be a permutation of 0..(rank-1), but was [1]" in str(e.value.__cause__)
542+
543+
544+
@gluon.jit
545+
def smem_and_layout_user(smem, a: ttgl.constexpr):
546+
pass
547+
548+
549+
def test_layout_mangling():
550+
551+
@gluon.jit
552+
def kernel():
553+
a: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0])
554+
smem = ttgl.allocate_shared_memory(ttgl.int32, [32, 32], a)
555+
smem_and_layout_user(smem, a)
556+
557+
expecttest.assert_expected_inline(
558+
run_parser(kernel).str_nodebug(), """\
559+
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
560+
#smem = #ttg.shared_memory
561+
module {
562+
tt.func public @kernel() attributes {noinline = false} {
563+
%0 = ttg.local_alloc : () -> !ttg.memdesc<32x32xi32, #shared, #smem, mutable>
564+
tt.call @"smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
565+
tt.return
566+
}
567+
tt.func private @"smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
568+
tt.return
569+
}
570+
}
571+
""")

python/triton/experimental/gluon/language/_core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def __neq__(self, other) -> bool:
167167
return not (self == other)
168168

169169
def mangle(self) -> str:
170-
shape_str = "_".join(self.shape)
170+
shape_str = "_".join([str(s) for s in self.shape])
171171
return f"MD{self.element_ty.mangle()}S{shape_str}SL{self.layout.mangle()}LAS{self.alloc_shape}ASMD"
172172

173173

@@ -290,6 +290,7 @@ def full(shape, value, dtype, layout, _builder=None):
290290
def allocate_shared_memory(element_ty, shape, layout, value=None, _builder=None):
291291
element_ty = _unwrap_if_constexpr(element_ty)
292292
shape = _unwrap_if_constexpr(shape)
293+
shape = [_unwrap_if_constexpr(s) for s in shape]
293294
layout = _unwrap_if_constexpr(layout)
294295
return semantic.allocate_shared(element_ty, shape, layout, value, _builder)
295296

python/triton/experimental/gluon/language/_layouts.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,4 +186,10 @@ def _to_ir(self, builder):
186186
)
187187

188188
def mangle(self) -> str:
189-
return f"NVMMA_{self.swizzle_byte_width}_{self.element_bitwidth}_{self.transposed}_{self.fp4_padded}_NVMMA"
189+
190+
def stringify(x):
191+
if x is None:
192+
return ""
193+
return "_".join(map(str, x))
194+
195+
return f"SSS_{self.vec}_{self.per_phase}_{self.max_phase}_{stringify(self.order)}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}_{stringify(self.cta_order)}_SSS"

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ def _to_ir(self, builder):
3939
cta_split_num,
4040
)
4141

42+
def mangle(self) -> str:
43+
block_str = f"{self.block[0]}x{self.block[1]}"
44+
unpacked_str = "U" if self.unpacked else "P"
45+
cta_split_str = f"CS{self.cta_split_num[0]}x{self.cta_split_num[1]}" if self.cta_split_num else ""
46+
return f"TL{block_str}{unpacked_str}{cta_split_str}TL"
47+
4248

4349
class tensor_memory_descriptor_type(base_type):
4450

@@ -75,7 +81,7 @@ def __neq__(self, other) -> bool:
7581
return not (self == other)
7682

7783
def mangle(self) -> str:
78-
shape_str = "_".join(self.shape)
84+
shape_str = "_".join([str(s) for s in self.shape])
7985
return f"MD{self.element_ty.mangle()}S{shape_str}SL{self.layout.mangle()}LAS{self.alloc_shape}ASMD"
8086

8187

python/triton/language/core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,6 +1308,9 @@ def _flatten_ir(self, handles: List[ir.value]):
13081308
for v in self.values:
13091309
v._flatten_ir(handles)
13101310

1311+
def __repr__(self):
1312+
return f"({' ,'.join(repr(x) for x in self.values)})"
1313+
13111314

13121315
class slice:
13131316

0 commit comments

Comments
 (0)