Skip to content

Commit fe8ee0c

Browse files
authored
[Frontend] Fix augassign into an attribute (#7023)
This makes `foo.bar += 42` work
1 parent e74db2d commit fe8ee0c

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

python/test/unit/language/test_frontend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,21 @@ def test_assign_attribute():
4747
anchor(pair)
4848

4949

50+
@filecheck_test
51+
@triton.jit
52+
def test_augassign_attribute():
53+
# CHECK-LABEL: test_augassign_attribute
54+
# CHECK: %c11_i32 = arith.constant 11 : i32
55+
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
56+
scalar = 11
57+
pair = Pair(tl.arange(0, 4), scalar)
58+
# CHECK: %c42_i32 = arith.constant 42 : i32
59+
# CHECK: [[VALUE:%.*]] = arith.addi %c11_i32, %c42_i32
60+
pair.second += 42
61+
# CHECK-NEXT: call @"anchor{{.*}}"([[RANGE]], [[VALUE]])
62+
anchor(pair)
63+
64+
5065
@filecheck_test
5166
@triton.jit
5267
def test_jit_method():

python/triton/compiler/code_generator.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ast
2+
import copy
23
import inspect
34
import re
45
import warnings
@@ -565,16 +566,14 @@ def visit_AnnAssign(self, node):
565566
return self.visit_Assign(node)
566567

567568
def assignTarget(self, target, value):
569+
assert isinstance(target.ctx, ast.Store)
568570
if isinstance(target, ast.Subscript):
569-
assert target.ctx.__class__.__name__ == "Store"
570571
return self.visit_Subscript_Store(target, value)
571572
if isinstance(target, ast.Tuple):
572-
assert target.ctx.__class__.__name__ == "Store"
573573
for i, name in enumerate(target.elts):
574574
self.set_value(self.visit(name), value.values[i])
575575
return
576576
if isinstance(target, ast.Attribute):
577-
assert target.ctx.__class__.__name__ == "Store"
578577
base = self.visit(target.value)
579578
setattr(base, target.attr, value)
580579
return
@@ -600,12 +599,12 @@ def _sanitize_value(value):
600599
self.assignTarget(targets[0], values)
601600

602601
def visit_AugAssign(self, node):
603-
name = node.target.id
604-
lhs = ast.Name(id=name, ctx=ast.Load())
602+
lhs = copy.deepcopy(node.target)
603+
lhs.ctx = ast.Load()
605604
rhs = ast.BinOp(lhs, node.op, node.value)
606605
assign = ast.Assign(targets=[node.target], value=rhs)
607606
self.visit(assign)
608-
return self.dereference_name(name)
607+
return self.visit(lhs)
609608

610609
def visit_Name(self, node):
611610
if type(node.ctx) is ast.Store:
@@ -995,15 +994,15 @@ def visit_While(self, node):
995994
ast.NodeVisitor.generic_visit(self, stmt)
996995

997996
def visit_Subscript_Load(self, node):
998-
assert node.ctx.__class__.__name__ == "Load"
997+
assert isinstance(node.ctx, ast.Load)
999998
lhs = self.visit(node.value)
1000999
slices = self.visit(node.slice)
10011000
if _is_triton_tensor(lhs):
10021001
return lhs.__getitem__(slices, _builder=self.builder)
10031002
return lhs[slices]
10041003

10051004
def visit_Subscript_Store(self, node, value):
1006-
assert node.ctx.__class__.__name__ == "Store"
1005+
assert isinstance(node.ctx, ast.Store)
10071006
lhs = self.visit(node.value)
10081007
slices = self.visit(node.slice)
10091008
assert isinstance(lhs, language.tuple)

0 commit comments

Comments
 (0)