1010from .lattice import (
1111 AnyRecord ,
1212 NotRecord ,
13+ RecordIdx ,
1314 RecordTuple ,
1415 InvalidRecord ,
1516 ConstantCarrier ,
@@ -31,6 +32,12 @@ def consumes_measurements(
3132 # Get the measurement results being consumed
3233 record_tuple_at_stmt = frame .get (stmt .measurements )
3334
35+ if not (
36+ isinstance (record_tuple_at_stmt , RecordTuple )
37+ and kirin_types .is_tuple_of (record_tuple_at_stmt .members , RecordIdx )
38+ ):
39+ return (InvalidRecord (),)
40+
3441 final_record_idxs = [
3542 deepcopy (record_idx ) for record_idx in record_tuple_at_stmt .members
3643 ]
@@ -122,7 +129,7 @@ def alias(
122129 stmt : py .Alias ,
123130 ):
124131 input = frame .get (stmt .value ) # expect this to be a RecordTuple
125-
132+ # frame.global_record_state.clone_record_idxs(input)
126133 # two variables share the same references in the global state
127134 return (input ,)
128135
@@ -134,20 +141,70 @@ def for_loop(
134141 self , interp_ : RecordAnalysis , frame : RecordFrame , stmt : scf .stmts .For
135142 ):
136143
137- loop_vars = frame .get_values (stmt .initializers )
138-
139- for _ in range (2 ):
140- loop_vars = interp_ .frame_call_region (
141- frame , stmt , stmt .body , InvalidRecord (), * loop_vars
144+ init_loop_vars = frame .get_values (stmt .initializers )
145+
146+ # You go through the loops twice to verify the loop invariant.
147+ # we need to freeze the frame entries right after exiting the loop
148+
149+ first_loop_frame = RecordFrame (
150+ stmt ,
151+ global_record_state = frame .global_record_state ,
152+ parent = frame ,
153+ has_parent_access = True ,
154+ )
155+ first_loop_vars = interp_ .frame_call_region (
156+ first_loop_frame , stmt , stmt .body , InvalidRecord (), * init_loop_vars
157+ )
158+
159+ if first_loop_vars is None :
160+ first_loop_vars = ()
161+ elif isinstance (first_loop_vars , interp .ReturnValue ):
162+ return first_loop_vars
163+
164+ captured_first_loop_entries = {}
165+ captured_first_loop_vars = deepcopy (first_loop_vars )
166+
167+ for ssa_val , lattice_element in first_loop_frame .entries .items ():
168+ captured_first_loop_entries [ssa_val ] = deepcopy (lattice_element )
169+
170+ second_loop_frame = RecordFrame (
171+ stmt ,
172+ global_record_state = frame .global_record_state ,
173+ parent = frame ,
174+ has_parent_access = True ,
175+ )
176+ second_loop_vars = interp_ .frame_call_region (
177+ second_loop_frame , stmt , stmt .body , InvalidRecord (), * first_loop_vars
178+ )
179+
180+ if second_loop_vars is None :
181+ second_loop_vars = ()
182+ elif isinstance (second_loop_vars , interp .ReturnValue ):
183+ return second_loop_vars
184+
185+ # take the entries in the first and second loops
186+ # update the parent frame
187+
188+ unified_frame_buffer = {}
189+ for ssa_val , lattice_element in captured_first_loop_entries .items ():
190+ verified_latticed_element = second_loop_frame .entries [ssa_val ].join (
191+ lattice_element
142192 )
193+ # print(f"Joining {lattice_element} and {second_loop_frame.entries[ssa_val]} to get {verified_latticed_element}")
194+ unified_frame_buffer [ssa_val ] = verified_latticed_element
195+
196+ frame .entries .update (unified_frame_buffer )
143197
144- if loop_vars is None :
145- loop_vars = ()
198+ if captured_first_loop_vars is None or second_loop_vars is None :
199+ return ()
146200
147- elif isinstance (loop_vars , interp .ReturnValue ):
148- return loop_vars
201+ joined_loop_vars = []
202+ for first_loop_var , second_loop_var in zip (
203+ captured_first_loop_vars , second_loop_vars
204+ ):
205+ joined_loop_vars .append (first_loop_var .join (second_loop_var ))
149206
150- return loop_vars
207+ return tuple ( joined_loop_vars )
151208
152209 @interp .impl (scf .stmts .Yield )
153210 def for_yield (
@@ -156,8 +213,6 @@ def for_yield(
156213 return interp .YieldValue (frame .get_values (stmt .values ))
157214
158215
159- # Only carry about carrying integers for now because
160- # the current issue is that
161216@py .dialect .register (key = "record" )
162217class ConstantForwarding (interp .MethodTable ):
163218 @interp .impl (py .Constant )
0 commit comments