Skip to content

Commit ac56e2f

Browse files
authored
[FRONTEND] Fix mangling for tuples (triton-lang#8060)
1 parent c6eae40 commit ac56e2f

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

python/test/unit/language/test_frontend.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,34 @@ def test_aggregate_with_constexpr():
308308
# CHECK: arith.addi %arg0, %cst : tensor<4xi32>
309309

310310

311+
@tl.core._aggregate
312+
class AggregateWithTuple:
313+
a: tl.tuple
314+
315+
def __init__(self, a):
316+
self.a = tl.tuple((a, ))
317+
318+
@staticmethod
319+
@triton.jit
320+
def create(a):
321+
return AggregateWithTuple(a)
322+
323+
324+
@triton.jit
325+
def pass_tuple_aggregate(agg):
326+
pass
327+
328+
329+
@filecheck_test
330+
@triton.jit
331+
def test_aggregate_with_tuple():
332+
# CHECK-LABEL: test_aggregate_with_tuple
333+
# CHECK: tt.call @"test_frontend.pass_tuple_aggregate__test_frontend.AggregateWithTuple<Ti32S4ST>__"
334+
agg = AggregateWithTuple.create(tl.arange(0, 4))
335+
pass_tuple_aggregate(agg)
336+
# CHECK: tt.func private @"test_frontend.pass_tuple_aggregate__test_frontend.AggregateWithTuple<Ti32S4ST>__"
337+
338+
311339
@triton.constexpr_function
312340
def constexpr_function(x):
313341
return x + 1

python/triton/language/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -777,7 +777,7 @@ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tuple, in
777777
return tuple(values, self), cursor
778778

779779
def mangle(self):
780-
return 'T' + '_'.join(ty.mangle for ty in self.types) + 'T'
780+
return 'T' + '_'.join(ty.mangle() for ty in self.types) + 'T'
781781

782782

783783
class slice_type(dtype):

0 commit comments

Comments
 (0)