Skip to content

Commit 5547174

Browse files
authored
[Frontend] Fix a few things in the frontend (#7060)
Fix assigning to tuples of other nodes, such as `a.x, b.y = unpack_me()` Fix `@constexpr_function` so that the functions can still be called from Python. Add `__eq__` to `constexpr_type` so that the parser can reconcile types of liveouts.
1 parent 72b2d9b commit 5547174

File tree

3 files changed

+67
-3
lines changed

3 files changed

+67
-3
lines changed

python/test/unit/language/test_frontend.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,11 @@ def __init__(self, a, b):
179179
def create(a):
180180
return AggregateWithConstexpr(a, tl.constexpr(42))
181181

182+
@triton.jit
183+
def modify(self, a):
184+
self.a = a
185+
return self
186+
182187

183188
@triton.jit
184189
def add_rhs_constexpr(agg):
@@ -196,3 +201,59 @@ def test_aggregate_with_constexpr():
196201
# CHECK: tt.func private @"test_frontend.add_rhs_constexpr__test_frontend.AggregateWithConstexpr<i32S4S, constexpr[42]>
197202
# CHECK: %cst = arith.constant dense<42> : tensor<4xi32>
198203
# CHECK: arith.addi %arg0, %cst : tensor<4xi32>
204+
205+
206+
@tl.constexpr_function
207+
def constexpr_function(x):
208+
return x + 1
209+
210+
211+
@filecheck_test
212+
@triton.jit
213+
def test_constexpr_function_from_jit():
214+
# CHECK-LABEL: test_constexpr_function
215+
x: tl.constexpr = constexpr_function(7)
216+
# CHECK: make_range {end = 8 : i32, start = 0 : i32}
217+
tl.arange(0, x)
218+
219+
220+
def test_constexpr_function_from_python():
221+
assert constexpr_function(7) == 8
222+
223+
224+
@triton.jit
225+
def swap(pair):
226+
return pair.second, pair.first
227+
228+
229+
@filecheck_test
230+
@triton.jit
231+
def test_assign_tuple_attrs():
232+
# CHECK-LABEL: test_assign_tuple_attrs
233+
p = Pair(tl.arange(0, 4), tl.arange(4, 8))
234+
# CHECK: [[P:%.*]]:2 = tt.call @{{.*}}swap
235+
p.first, p.second = swap(p)
236+
# CHECK: call @{{.*}}anchor{{.*}}([[P]]#0)
237+
# CHECK: call @{{.*}}anchor{{.*}}([[P]]#1)
238+
anchor(p.first)
239+
anchor(p.second)
240+
241+
242+
@filecheck_test
243+
@triton.jit
244+
def test_reassign_aggregate_with_constexpr():
245+
# CHECK-LABEL: test_reassign_aggregate_with_constexpr
246+
agg = AggregateWithConstexpr.create(tl.arange(0, 4))
247+
var = 1
248+
# CHECK: [[AGG:%.*]] = scf.if {{.*}} -> (tensor<4xi32>)
249+
# CHECK: [[VALUE:%.*]] = tt.call {{.*}}modify
250+
# CHECK: yield [[VALUE]]
251+
# CHECK: else
252+
# CHECK: [[VALUE:%.*]] = tt.call {{.*}}modify
253+
# CHECK: yield [[VALUE]]
254+
if var == 0:
255+
agg = agg.modify(tl.arange(4, 8))
256+
else:
257+
agg = agg.modify(tl.arange(8, 12))
258+
# CHECK: call @{{.*}}anchor{{.*}}([[AGG]])
259+
anchor(agg)

python/triton/compiler/code_generator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -571,8 +571,8 @@ def assignTarget(self, target, value):
571571
if isinstance(target, ast.Subscript):
572572
return self.visit_Subscript_Store(target, value)
573573
if isinstance(target, ast.Tuple):
574-
for i, name in enumerate(target.elts):
575-
self.set_value(self.visit(name), value.values[i])
574+
for i, target in enumerate(target.elts):
575+
self.assignTarget(target, value.values[i])
576576
return
577577
if isinstance(target, ast.Attribute):
578578
base = self.visit(target.value)

python/triton/language/core.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ class constexpr_type(base_type):
178178
def __init__(self, value):
179179
self.value = value
180180

181+
def __eq__(self, other):
182+
return self.value == other.value
183+
181184
def __repr__(self) -> str:
182185
return f"constexpr[{self.value}]"
183186

@@ -338,7 +341,7 @@ def constexpr_function(f):
338341
@wraps(f)
339342
def wrapper(*args, **kwargs):
340343
# de-constexpr arguments and discard the _builder keyword argument:
341-
args = [getattr(x, "value", x) for x in args]
344+
args = [_unwrap_if_constexpr(x) for x in args]
342345
kwargs = {k: getattr(v, "value", v) for (k, v) in kwargs.items() if k != "_builder"}
343346

344347
# call the raw Python function f:

0 commit comments

Comments
 (0)