@@ -38,9 +38,15 @@ def measure_qubit_list(
3838 if not isinstance (num_qubits , kirin_types .Literal ):
3939 return (AnyMeasureId (),)
4040
41- record_idxs = frame .global_record_state .add_record_idxs (num_qubits .data )
41+ # increment the parent frame measure count offset.
42+ # Loop analysis relies on local state tracking
43+ # so we use this data after exiting a loop to
44+ # readjust the previous global measure count.
45+ frame .measure_count_offset += num_qubits .data
4246
43- return (MeasureIdTuple (data = tuple (record_idxs )),)
47+ measure_id_tuple = frame .global_record_state .add_record_idxs (num_qubits .data )
48+
49+ return (measure_id_tuple ,)
4450
4551
4652@annotate .dialect .register (key = "measure_id" )
@@ -130,11 +136,16 @@ def getitem(
130136class PyAssign (interp .MethodTable ):
131137 @interp .impl (py .Alias )
132138 def alias (
133- self , interp : MeasurementIDAnalysis , frame : interp .Frame , stmt : py .assign .Alias
139+ self ,
140+ interp : MeasurementIDAnalysis ,
141+ frame : MeasureIDFrame ,
142+ stmt : py .assign .Alias ,
134143 ):
135144
136145 input = frame .get (stmt .value )
137- return (input ,)
146+
147+ new_input = frame .global_record_state .clone_record_idxs (input )
148+ return (new_input ,)
138149
139150
140151@py .binop .dialect .register (key = "measure_id" )
@@ -183,9 +194,11 @@ def for_loop(
183194 # You go through the loops twice to verify the loop invariant.
184195 # we need to freeze the frame entries right after exiting the loop
185196
197+ local_state = deepcopy (frame .global_record_state )
198+
186199 first_loop_frame = MeasureIDFrame (
187200 stmt ,
188- global_record_state = frame . global_record_state ,
201+ global_record_state = local_state ,
189202 parent = frame ,
190203 has_parent_access = True ,
191204 )
@@ -206,7 +219,7 @@ def for_loop(
206219
207220 second_loop_frame = MeasureIDFrame (
208221 stmt ,
209- global_record_state = frame . global_record_state ,
222+ global_record_state = local_state ,
210223 parent = frame ,
211224 has_parent_access = True ,
212225 )
@@ -231,6 +244,9 @@ def for_loop(
231244 unified_frame_buffer [ssa_val ] = verified_latticed_element
232245
233246 frame .entries .update (unified_frame_buffer )
247+ frame .global_record_state .offset_existing_records (
248+ first_loop_frame .measure_count_offset
249+ )
234250
235251 if captured_first_loop_vars is None or second_loop_vars is None :
236252 return ()
@@ -241,6 +257,20 @@ def for_loop(
241257 ):
242258 joined_loop_vars .append (first_loop_var .join (second_loop_var ))
243259
260+ # TrimYield is currently disabled meaning that the same RecordIdx
261+ # can get copied into the parent frame twice! As a result
262+ # we need to be careful to only add unique RecordIdx entries
263+ witnessed_record_idxs = set ()
264+ for var in joined_loop_vars :
265+ if isinstance (var , MeasureIdTuple ):
266+ for member in var .data :
267+ if (
268+ isinstance (member , KnownMeasureId )
269+ and member .idx not in witnessed_record_idxs
270+ ):
271+ witnessed_record_idxs .add (member .idx )
272+ frame .global_record_state .buffer .append (member )
273+
244274 return tuple (joined_loop_vars )
245275
246276 @interp .impl (scf .stmts .Yield )
0 commit comments