@@ -106,15 +106,26 @@ def unflatten_ir_values(handles: List[ir.value], types: List[base_type]):
106106_condition_types = {bool , int , type (None )} # Python types accepted for conditionals inside kernels
107107
108108
109+ def _clone_triton_value (val ):
110+ handles = []
111+ val ._flatten_ir (handles )
112+ clone , _ = val .type ._unflatten_ir (handles , 0 )
113+ return clone
114+
115+
116+ def _clone_scope (scope ):
117+ return {name : _clone_triton_value (val ) if _is_triton_value (val ) else val for name , val in scope .items ()}
118+
119+
109120class enter_sub_region :
110121
111122 def __init__ (self , generator ):
112123 self .generator = generator
113124
114125 def __enter__ (self ):
115126 # record lscope & local_defs in the parent scope
116- self .liveins = self .generator .lscope . copy ( )
117- self .prev_defs = self .generator .local_defs . copy ( )
127+ self .liveins = _clone_scope ( self .generator .lscope )
128+ self .prev_defs = _clone_scope ( self .generator .local_defs )
118129 self .generator .local_defs = {}
119130 self .insert_block = self .generator .builder .get_insertion_block ()
120131 self .insert_point = self .generator .builder .get_insertion_point ()
@@ -436,8 +447,6 @@ def _set_insertion_point_and_loc(self, ip, loc):
436447 self .builder .set_loc (loc )
437448
438449 def _find_carries (self , node , liveins ):
439- # We must extract the handles before the value is editted in the loop
440- livehandles = {name : flatten_values_to_ir ([v ]) for name , v in liveins .items () if _is_triton_value (v )}
441450 # create loop body block
442451 block = self .builder .create_block ()
443452 self .builder .set_insertion_point_to_start (block )
@@ -457,13 +466,11 @@ def _find_carries(self, node, liveins):
457466 for name , live_val in liveins .items ():
458467 if _is_triton_value (live_val ):
459468 loop_val = self .lscope [name ]
460- assert type ( live_val ) is type ( loop_val ), f'Loop carried variable { name } changed type'
469+ self . _verify_loop_carried_variable ( name , loop_val , live_val )
461470
462- live_handles = livehandles [ name ]
471+ live_handles = flatten_values_to_ir ([ live_val ])
463472 loop_handles = flatten_values_to_ir ([loop_val ])
464473 if live_handles != loop_handles :
465- self ._verify_loop_carried_variable (name , loop_val , live_val )
466-
467474 names .append (name )
468475 init_tys .append (live_val .type )
469476 init_handles .extend (live_handles )
0 commit comments