Skip to content

Commit 5b50c7f

Browse files
authored
[Frontend] Allow @_aggregate types to contain constexprs (#7024)
Constexpr fields don't have an IR representation and are carried through the type. Also ensure that they are mangled into the type name. This PR achieves this by adding a `constexpt_type` that contains the constexpr value itself and makes this the type of `constexpr`. Shoutout to @peterbell10 for the suggestion
1 parent e4c7fe8 commit 5b50c7f

File tree

2 files changed

+89
-36
lines changed

2 files changed

+89
-36
lines changed

python/test/unit/language/test_frontend.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,35 @@ def test_call_in_loop():
141141
# CHECK: call @accumulate
142142
for i in range(10):
143143
acc = accumulate(acc, i)
144+
145+
146+
@tl.core._aggregate
147+
class AggregateWithConstexpr:
148+
a: tl.tensor
149+
b: tl.constexpr
150+
151+
def __init__(self, a, b):
152+
self.a = a
153+
self.b = b
154+
155+
@staticmethod
156+
def create(a):
157+
return AggregateWithConstexpr(a, tl.constexpr(42))
158+
159+
160+
@triton.jit
161+
def add_rhs_constexpr(agg):
162+
_ = agg.a + agg.b
163+
164+
165+
@filecheck_test
166+
@triton.jit
167+
def test_aggregate_with_constexpr():
168+
# CHECK-LABEL: test_aggregate_with_constexpr
169+
# CHECK: tt.call @"add_rhs_constexpr__test_frontend.AggregateWithConstexpr<i32S4S, constexpr[42]>
170+
agg = AggregateWithConstexpr.create(tl.arange(0, 4))
171+
add_rhs_constexpr(agg)
172+
173+
# CHECK: tt.func private @"add_rhs_constexpr__test_frontend.AggregateWithConstexpr<i32S4S, constexpr[42]>
174+
# CHECK: %cst = arith.constant dense<42> : tensor<4xi32>
175+
# CHECK: arith.addi %arg0, %cst : tensor<4xi32>

python/triton/language/core.py

Lines changed: 57 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,59 @@ class const:
139139
pass
140140

141141

142-
class constexpr:
142+
class base_value:
143+
"""Base class of values that exist in the triton IR (i.e. not constexprs).
144+
"""
145+
type: base_type
146+
147+
def _flatten_ir(self, handles: List[ir.value]) -> None:
148+
"""Flatten frontend value into a sequence of mlir handles, which are appended
149+
to the output list
150+
"""
151+
raise NotImplementedError
152+
153+
154+
class base_type:
155+
156+
def __eq__(self, other):
157+
raise NotImplementedError("Types must implement __eq__")
158+
159+
def __ne__(self, other):
160+
return not (self == other)
161+
162+
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
163+
"""Build a frontend value with the current dtype, wrapping a list of existing handles.
164+
cursor is the index of the first handle relevant to this value, and the function
165+
should return the updated cursor position after any handles consumed by the created value.
166+
"""
167+
raise NotImplementedError
168+
169+
def mangle(self) -> str:
170+
raise NotImplementedError(f"NYI: Type mangling for type {self.__class__}")
171+
172+
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
173+
raise NotImplementedError
174+
175+
176+
class constexpr_type(base_type):
177+
178+
def __init__(self, value):
179+
self.value = value
180+
181+
def __repr__(self) -> str:
182+
return f"constexpr[{self.value}]"
183+
184+
def mangle(self) -> str:
185+
return repr(self)
186+
187+
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
188+
return
189+
190+
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
191+
return constexpr(self.value), cursor
192+
193+
194+
class constexpr(base_value):
143195
"""
144196
This class is used to store a value that is known at compile-time.
145197
"""
@@ -149,11 +201,14 @@ def __init__(self, value):
149201
self.value = value.value
150202
else:
151203
self.value = value
152-
self.type = constexpr
204+
self.type = constexpr_type(value)
153205

154206
def __repr__(self) -> str:
155207
return f"constexpr[{self.value}]"
156208

209+
def _flatten_ir(self, handles: List[ir.value]) -> None:
210+
return
211+
157212
def __index__(self):
158213
return self.value
159214

@@ -322,40 +377,6 @@ def check_bit_width(value, shift_value):
322377
)
323378

324379

325-
class base_value:
326-
"""Base class of values that exist in the triton IR (i.e. not constexprs).
327-
"""
328-
type: base_type
329-
330-
def _flatten_ir(self, handles: List[ir.value]) -> None:
331-
"""Flatten frontend value into a sequence of mlir handles, which are appended
332-
to the output list
333-
"""
334-
raise NotImplementedError
335-
336-
337-
class base_type:
338-
339-
def __eq__(self, other):
340-
raise NotImplementedError("Types must implement __eq__")
341-
342-
def __ne__(self, other):
343-
return not (self == other)
344-
345-
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
346-
"""Build a frontend value with the current dtype, wrapping a list of existing handles.
347-
cursor is the index of the first handle relevant to this value, and the function
348-
should return the updated cursor position after any handles consumed by the created value.
349-
"""
350-
raise NotImplementedError
351-
352-
def mangle(self) -> str:
353-
raise NotImplementedError(f"NYI: Type mangling for type {self.__class__}")
354-
355-
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
356-
raise NotImplementedError
357-
358-
359380
# -----------------------
360381
# dtype
361382
# -----------------------

0 commit comments

Comments
 (0)