Skip to content

Commit 4920273

Browse files
committed
remove more record analysis debug prints and move everything into measurement_id
1 parent d4113e6 commit 4920273

File tree

4 files changed

+60
-13
lines changed

4 files changed

+60
-13
lines changed

src/bloqade/analysis/measure_id/analysis.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
from kirin.analysis import ForwardExtra
55
from kirin.analysis.forward import ForwardFrame
66

7-
from .lattice import MeasureId, NotMeasureId, KnownMeasureId
7+
from .lattice import MeasureId, NotMeasureId, KnownMeasureId, MeasureIdTuple
88

99

1010
@dataclass
1111
class GlobalRecordState:
1212
buffer: list[KnownMeasureId] = field(default_factory=list)
1313

1414
# assume that this KnownMeasureId will always be -1
15-
def add_record_idxs(self, num_new_records: int) -> list[KnownMeasureId]:
15+
def add_record_idxs(self, num_new_records: int) -> MeasureIdTuple:
1616
# adjust all previous indices
1717
for record_idx in self.buffer:
1818
record_idx.idx -= num_new_records
@@ -21,12 +21,33 @@ def add_record_idxs(self, num_new_records: int) -> list[KnownMeasureId]:
2121
new_record_idxs = [KnownMeasureId(-i) for i in range(num_new_records, 0, -1)]
2222
self.buffer += new_record_idxs
2323
# Return for usage, idxs linked to the global state
24-
return new_record_idxs
24+
return MeasureIdTuple(data=tuple(new_record_idxs))
25+
26+
# Need for loop invariance, especially when you
27+
# run the loop twice "behind the scenes". Then
28+
# it isn't sufficient to just have two
29+
# copies of a lattice element point to one entry on the
30+
# buffer
31+
def clone_record_idxs(self, measure_id_tuple: MeasureIdTuple) -> MeasureIdTuple:
32+
cloned_members = []
33+
for known_measure_id in measure_id_tuple.data:
34+
assert isinstance(known_measure_id, KnownMeasureId)
35+
cloned_known_measure_id = KnownMeasureId(known_measure_id.idx)
36+
# put into the global buffer but also
37+
# return an analysis-facing copy
38+
self.buffer.append(cloned_known_measure_id)
39+
cloned_members.append(cloned_known_measure_id)
40+
return MeasureIdTuple(data=tuple(cloned_members))
41+
42+
def offset_existing_records(self, offset: int):
43+
for record_idx in self.buffer:
44+
record_idx.idx -= offset
2545

2646

2747
@dataclass
2848
class MeasureIDFrame(ForwardFrame[MeasureId]):
2949
global_record_state: GlobalRecordState = field(default_factory=GlobalRecordState)
50+
measure_count_offset: int = 0
3051

3152

3253
class MeasurementIDAnalysis(ForwardExtra[MeasureIDFrame, MeasureId]):

src/bloqade/analysis/measure_id/impls.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,15 @@ def measure_qubit_list(
3838
if not isinstance(num_qubits, kirin_types.Literal):
3939
return (AnyMeasureId(),)
4040

41-
record_idxs = frame.global_record_state.add_record_idxs(num_qubits.data)
41+
# increment the parent frame measure count offset.
42+
# Loop analysis relies on local state tracking
43+
# so we use this data after exiting a loop to
44+
# readjust the previous global measure count.
45+
frame.measure_count_offset += num_qubits.data
4246

43-
return (MeasureIdTuple(data=tuple(record_idxs)),)
47+
measure_id_tuple = frame.global_record_state.add_record_idxs(num_qubits.data)
48+
49+
return (measure_id_tuple,)
4450

4551

4652
@annotate.dialect.register(key="measure_id")
@@ -130,11 +136,16 @@ def getitem(
130136
class PyAssign(interp.MethodTable):
131137
@interp.impl(py.Alias)
132138
def alias(
133-
self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.assign.Alias
139+
self,
140+
interp: MeasurementIDAnalysis,
141+
frame: MeasureIDFrame,
142+
stmt: py.assign.Alias,
134143
):
135144

136145
input = frame.get(stmt.value)
137-
return (input,)
146+
147+
new_input = frame.global_record_state.clone_record_idxs(input)
148+
return (new_input,)
138149

139150

140151
@py.binop.dialect.register(key="measure_id")
@@ -183,9 +194,11 @@ def for_loop(
183194
# You go through the loops twice to verify the loop invariant.
184195
# we need to freeze the frame entries right after exiting the loop
185196

197+
local_state = deepcopy(frame.global_record_state)
198+
186199
first_loop_frame = MeasureIDFrame(
187200
stmt,
188-
global_record_state=frame.global_record_state,
201+
global_record_state=local_state,
189202
parent=frame,
190203
has_parent_access=True,
191204
)
@@ -206,7 +219,7 @@ def for_loop(
206219

207220
second_loop_frame = MeasureIDFrame(
208221
stmt,
209-
global_record_state=frame.global_record_state,
222+
global_record_state=local_state,
210223
parent=frame,
211224
has_parent_access=True,
212225
)
@@ -231,6 +244,9 @@ def for_loop(
231244
unified_frame_buffer[ssa_val] = verified_latticed_element
232245

233246
frame.entries.update(unified_frame_buffer)
247+
frame.global_record_state.offset_existing_records(
248+
first_loop_frame.measure_count_offset
249+
)
234250

235251
if captured_first_loop_vars is None or second_loop_vars is None:
236252
return ()
@@ -241,6 +257,20 @@ def for_loop(
241257
):
242258
joined_loop_vars.append(first_loop_var.join(second_loop_var))
243259

260+
# TrimYield is currently disabled meaning that the same RecordIdx
261+
# can get copied into the parent frame twice! As a result
262+
# we need to be careful to only add unique RecordIdx entries
263+
witnessed_record_idxs = set()
264+
for var in joined_loop_vars:
265+
if isinstance(var, MeasureIdTuple):
266+
for member in var.data:
267+
if (
268+
isinstance(member, KnownMeasureId)
269+
and member.idx not in witnessed_record_idxs
270+
):
271+
witnessed_record_idxs.add(member.idx)
272+
frame.global_record_state.buffer.append(member)
273+
244274
return tuple(joined_loop_vars)
245275

246276
@interp.impl(scf.stmts.Yield)

src/bloqade/analysis/record/analysis.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,6 @@ def clone_record_idxs(self, record_tuple: RecordTuple, id: int) -> RecordTuple:
4242
def offset_existing_records(self, offset: int):
4343
for record_idx in self.buffer:
4444
record_idx.idx -= offset
45-
print("offset is now:", offset)
46-
print("The record idx is now:", record_idx.idx)
47-
# print the record_idx after offsetting
4845

4946
"""
5047
Might need a free after use! You can keep the size of the list small

src/bloqade/analysis/record/impls.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,6 @@ def for_loop_double_pass(
245245
def for_yield(
246246
self, interp_: RecordAnalysis, frame: RecordFrame, stmt: scf.stmts.Yield
247247
):
248-
print("yield encountered, yielding values:", frame.get_values(stmt.values))
249248
return interp.YieldValue(frame.get_values(stmt.values))
250249

251250

0 commit comments

Comments
 (0)