@@ -106,15 +106,26 @@ def unflatten_ir_values(handles: List[ir.value], types: List[base_type]):
106
106
_condition_types = {bool , int , type (None )} # Python types accepted for conditionals inside kernels
107
107
108
108
109
+ def _clone_triton_value (val ):
110
+ handles = []
111
+ val ._flatten_ir (handles )
112
+ clone , _ = val .type ._unflatten_ir (handles , 0 )
113
+ return clone
114
+
115
+
116
+ def _clone_scope (scope ):
117
+ return {name : _clone_triton_value (val ) if _is_triton_value (val ) else val for name , val in scope .items ()}
118
+
119
+
109
120
class enter_sub_region :
110
121
111
122
def __init__ (self , generator ):
112
123
self .generator = generator
113
124
114
125
def __enter__ (self ):
115
126
# record lscope & local_defs in the parent scope
116
- self .liveins = self .generator .lscope . copy ( )
117
- self .prev_defs = self .generator .local_defs . copy ( )
127
+ self .liveins = _clone_scope ( self .generator .lscope )
128
+ self .prev_defs = _clone_scope ( self .generator .local_defs )
118
129
self .generator .local_defs = {}
119
130
self .insert_block = self .generator .builder .get_insertion_block ()
120
131
self .insert_point = self .generator .builder .get_insertion_point ()
@@ -436,8 +447,6 @@ def _set_insertion_point_and_loc(self, ip, loc):
436
447
self .builder .set_loc (loc )
437
448
438
449
def _find_carries (self , node , liveins ):
439
- # We must extract the handles before the value is editted in the loop
440
- livehandles = {name : flatten_values_to_ir ([v ]) for name , v in liveins .items () if _is_triton_value (v )}
441
450
# create loop body block
442
451
block = self .builder .create_block ()
443
452
self .builder .set_insertion_point_to_start (block )
@@ -457,13 +466,11 @@ def _find_carries(self, node, liveins):
457
466
for name , live_val in liveins .items ():
458
467
if _is_triton_value (live_val ):
459
468
loop_val = self .lscope [name ]
460
- assert type ( live_val ) is type ( loop_val ), f'Loop carried variable { name } changed type'
469
+ self . _verify_loop_carried_variable ( name , loop_val , live_val )
461
470
462
- live_handles = livehandles [ name ]
471
+ live_handles = flatten_values_to_ir ([ live_val ])
463
472
loop_handles = flatten_values_to_ir ([loop_val ])
464
473
if live_handles != loop_handles :
465
- self ._verify_loop_carried_variable (name , loop_val , live_val )
466
-
467
474
names .append (name )
468
475
init_tys .append (live_val .type )
469
476
init_handles .extend (live_handles )
0 commit comments