1
1
import ast
2
+ import copy
2
3
import inspect
3
4
import re
4
5
import warnings
@@ -565,16 +566,14 @@ def visit_AnnAssign(self, node):
565
566
return self .visit_Assign (node )
566
567
567
568
def assignTarget (self , target , value ):
569
+ assert isinstance (target .ctx , ast .Store )
568
570
if isinstance (target , ast .Subscript ):
569
- assert target .ctx .__class__ .__name__ == "Store"
570
571
return self .visit_Subscript_Store (target , value )
571
572
if isinstance (target , ast .Tuple ):
572
- assert target .ctx .__class__ .__name__ == "Store"
573
573
for i , name in enumerate (target .elts ):
574
574
self .set_value (self .visit (name ), value .values [i ])
575
575
return
576
576
if isinstance (target , ast .Attribute ):
577
- assert target .ctx .__class__ .__name__ == "Store"
578
577
base = self .visit (target .value )
579
578
setattr (base , target .attr , value )
580
579
return
@@ -600,12 +599,12 @@ def _sanitize_value(value):
600
599
self .assignTarget (targets [0 ], values )
601
600
602
601
def visit_AugAssign (self , node ):
603
- name = node .target . id
604
- lhs = ast .Name ( id = name , ctx = ast . Load () )
602
+ lhs = copy . deepcopy ( node .target )
603
+ lhs . ctx = ast .Load ()
605
604
rhs = ast .BinOp (lhs , node .op , node .value )
606
605
assign = ast .Assign (targets = [node .target ], value = rhs )
607
606
self .visit (assign )
608
- return self .dereference_name ( name )
607
+ return self .visit ( lhs )
609
608
610
609
def visit_Name (self , node ):
611
610
if type (node .ctx ) is ast .Store :
@@ -995,15 +994,15 @@ def visit_While(self, node):
995
994
ast .NodeVisitor .generic_visit (self , stmt )
996
995
997
996
def visit_Subscript_Load (self , node ):
998
- assert node .ctx . __class__ . __name__ == " Load"
997
+ assert isinstance ( node .ctx , ast . Load )
999
998
lhs = self .visit (node .value )
1000
999
slices = self .visit (node .slice )
1001
1000
if _is_triton_tensor (lhs ):
1002
1001
return lhs .__getitem__ (slices , _builder = self .builder )
1003
1002
return lhs [slices ]
1004
1003
1005
1004
def visit_Subscript_Store (self , node , value ):
1006
- assert node .ctx . __class__ . __name__ == " Store"
1005
+ assert isinstance ( node .ctx , ast . Store )
1007
1006
lhs = self .visit (node .value )
1008
1007
slices = self .visit (node .slice )
1009
1008
assert isinstance (lhs , language .tuple )
0 commit comments