Skip to content

Commit 09d5113

Browse files
authored
[Frontend] Fix scope enter to do a deep copy of scopes (#7271)
Follow up to #7200
1 parent ac84d71 commit 09d5113

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

python/triton/compiler/code_generator.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,26 @@ def unflatten_ir_values(handles: List[ir.value], types: List[base_type]):
106106
_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels
107107

108108

109+
def _clone_triton_value(val):
110+
handles = []
111+
val._flatten_ir(handles)
112+
clone, _ = val.type._unflatten_ir(handles, 0)
113+
return clone
114+
115+
116+
def _clone_scope(scope):
117+
return {name: _clone_triton_value(val) if _is_triton_value(val) else val for name, val in scope.items()}
118+
119+
109120
class enter_sub_region:
110121

111122
def __init__(self, generator):
112123
self.generator = generator
113124

114125
def __enter__(self):
115126
# record lscope & local_defs in the parent scope
116-
self.liveins = self.generator.lscope.copy()
117-
self.prev_defs = self.generator.local_defs.copy()
127+
self.liveins = _clone_scope(self.generator.lscope)
128+
self.prev_defs = _clone_scope(self.generator.local_defs)
118129
self.generator.local_defs = {}
119130
self.insert_block = self.generator.builder.get_insertion_block()
120131
self.insert_point = self.generator.builder.get_insertion_point()
@@ -436,8 +447,6 @@ def _set_insertion_point_and_loc(self, ip, loc):
436447
self.builder.set_loc(loc)
437448

438449
def _find_carries(self, node, liveins):
439-
# We must extract the handles before the value is editted in the loop
440-
livehandles = {name: flatten_values_to_ir([v]) for name, v in liveins.items() if _is_triton_value(v)}
441450
# create loop body block
442451
block = self.builder.create_block()
443452
self.builder.set_insertion_point_to_start(block)
@@ -457,13 +466,11 @@ def _find_carries(self, node, liveins):
457466
for name, live_val in liveins.items():
458467
if _is_triton_value(live_val):
459468
loop_val = self.lscope[name]
460-
assert type(live_val) is type(loop_val), f'Loop carried variable {name} changed type'
469+
self._verify_loop_carried_variable(name, loop_val, live_val)
461470

462-
live_handles = livehandles[name]
471+
live_handles = flatten_values_to_ir([live_val])
463472
loop_handles = flatten_values_to_ir([loop_val])
464473
if live_handles != loop_handles:
465-
self._verify_loop_carried_variable(name, loop_val, live_val)
466-
467474
names.append(name)
468475
init_tys.append(live_val.type)
469476
init_handles.extend(live_handles)

0 commit comments

Comments
 (0)