Skip to content

Commit 09130c6

Browse files
committed
loop invariance support complete
1 parent 4e1da7b commit 09130c6

File tree

3 files changed

+81
-15
lines changed

3 files changed

+81
-15
lines changed

src/bloqade/analysis/record/analysis.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,19 @@ def add_record_idxs(self, num_new_records: int) -> list[RecordIdx]:
2323
# Return for usage, idxs linked to the global state
2424
return new_record_idxs
2525

26+
"""
27+
def clone_record_idxs(self, record_tuple: RecordTuple) -> RecordTuple:
28+
cloned_members = []
29+
for record_idx in record_tuple.members:
30+
cloned_record_idx = RecordIdx(record_idx.idx)
31+
# put into the global buffer but also
32+
# return an analysis-facing copy
33+
self.buffer.append(cloned_record_idx)
34+
cloned_members.append(cloned_record_idx)
35+
36+
return RecordTuple(members=tuple(cloned_members))
37+
"""
38+
2639
"""
2740
Might need a free after use! You can keep the size of the list small
2841
but could be a premature optimization...

src/bloqade/analysis/record/impls.py

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .lattice import (
1111
AnyRecord,
1212
NotRecord,
13+
RecordIdx,
1314
RecordTuple,
1415
InvalidRecord,
1516
ConstantCarrier,
@@ -31,6 +32,12 @@ def consumes_measurements(
3132
# Get the measurement results being consumed
3233
record_tuple_at_stmt = frame.get(stmt.measurements)
3334

35+
if not (
36+
isinstance(record_tuple_at_stmt, RecordTuple)
37+
and kirin_types.is_tuple_of(record_tuple_at_stmt.members, RecordIdx)
38+
):
39+
return (InvalidRecord(),)
40+
3441
final_record_idxs = [
3542
deepcopy(record_idx) for record_idx in record_tuple_at_stmt.members
3643
]
@@ -122,7 +129,7 @@ def alias(
122129
stmt: py.Alias,
123130
):
124131
input = frame.get(stmt.value) # expect this to be a RecordTuple
125-
132+
# frame.global_record_state.clone_record_idxs(input)
126133
# two variables share the same references in the global state
127134
return (input,)
128135

@@ -134,20 +141,70 @@ def for_loop(
134141
self, interp_: RecordAnalysis, frame: RecordFrame, stmt: scf.stmts.For
135142
):
136143

137-
loop_vars = frame.get_values(stmt.initializers)
138-
139-
for _ in range(2):
140-
loop_vars = interp_.frame_call_region(
141-
frame, stmt, stmt.body, InvalidRecord(), *loop_vars
144+
init_loop_vars = frame.get_values(stmt.initializers)
145+
146+
# You go through the loops twice to verify the loop invariant.
147+
# we need to freeze the frame entries right after exiting the loop
148+
149+
first_loop_frame = RecordFrame(
150+
stmt,
151+
global_record_state=frame.global_record_state,
152+
parent=frame,
153+
has_parent_access=True,
154+
)
155+
first_loop_vars = interp_.frame_call_region(
156+
first_loop_frame, stmt, stmt.body, InvalidRecord(), *init_loop_vars
157+
)
158+
159+
if first_loop_vars is None:
160+
first_loop_vars = ()
161+
elif isinstance(first_loop_vars, interp.ReturnValue):
162+
return first_loop_vars
163+
164+
captured_first_loop_entries = {}
165+
captured_first_loop_vars = deepcopy(first_loop_vars)
166+
167+
for ssa_val, lattice_element in first_loop_frame.entries.items():
168+
captured_first_loop_entries[ssa_val] = deepcopy(lattice_element)
169+
170+
second_loop_frame = RecordFrame(
171+
stmt,
172+
global_record_state=frame.global_record_state,
173+
parent=frame,
174+
has_parent_access=True,
175+
)
176+
second_loop_vars = interp_.frame_call_region(
177+
second_loop_frame, stmt, stmt.body, InvalidRecord(), *first_loop_vars
178+
)
179+
180+
if second_loop_vars is None:
181+
second_loop_vars = ()
182+
elif isinstance(second_loop_vars, interp.ReturnValue):
183+
return second_loop_vars
184+
185+
# take the entries in the first and second loops
186+
# update the parent frame
187+
188+
unified_frame_buffer = {}
189+
for ssa_val, lattice_element in captured_first_loop_entries.items():
190+
verified_latticed_element = second_loop_frame.entries[ssa_val].join(
191+
lattice_element
142192
)
193+
# print(f"Joining {lattice_element} and {second_loop_frame.entries[ssa_val]} to get {verified_latticed_element}")
194+
unified_frame_buffer[ssa_val] = verified_latticed_element
195+
196+
frame.entries.update(unified_frame_buffer)
143197

144-
if loop_vars is None:
145-
loop_vars = ()
198+
if captured_first_loop_vars is None or second_loop_vars is None:
199+
return ()
146200

147-
elif isinstance(loop_vars, interp.ReturnValue):
148-
return loop_vars
201+
joined_loop_vars = []
202+
for first_loop_var, second_loop_var in zip(
203+
captured_first_loop_vars, second_loop_vars
204+
):
205+
joined_loop_vars.append(first_loop_var.join(second_loop_var))
149206

150-
return loop_vars
207+
return tuple(joined_loop_vars)
151208

152209
@interp.impl(scf.stmts.Yield)
153210
def for_yield(
@@ -156,8 +213,6 @@ def for_yield(
156213
return interp.YieldValue(frame.get_values(stmt.values))
157214

158215

159-
# Only carry about carrying integers for now because
160-
# the current issue is that
161216
@py.dialect.register(key="record")
162217
class ConstantForwarding(interp.MethodTable):
163218
@interp.impl(py.Constant)

test/analysis/record/test_record_analysis.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,13 +162,11 @@ def test():
162162
squin.annotate.set_detector([prev_ms[0], curr_ms[0]], coordinates=[0, 0])
163163
squin.annotate.set_detector([prev_ms[1], curr_ms[1]], coordinates=[0, 1])
164164

165-
squin.annotate.set_detector(curr_ms, coordinates=[0, 0])
166165
data_ms = squin.broadcast.measure(data_qs)
167166

168167
squin.set_detector([data_ms[0], data_ms[1], curr_ms[0]], coordinates=[2, 0])
169168
squin.set_detector([data_ms[2], data_ms[1], curr_ms[1]], coordinates=[2, 1])
170169

171-
test.print()
172170
SoftFlatten(dialects=test.dialects).fixpoint(test)
173171
test.print()
174172
frame, _ = RecordAnalysis(dialects=test.dialects).run(test)

0 commit comments

Comments
 (0)