Skip to content

Commit 4e1da7b

Browse files
committed
almost there, still a problem with invariance checking
1 parent eee712d commit 4e1da7b

File tree

9 files changed

+251
-67
lines changed

9 files changed

+251
-67
lines changed
Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,35 @@
1-
from typing import TypeVar
21
from dataclasses import field, dataclass
32

43
from kirin import ir
5-
from kirin.analysis import ForwardExtra, const
4+
from kirin.analysis import ForwardExtra
65
from kirin.analysis.forward import ForwardFrame
76

87
from .lattice import Record, RecordIdx
98

109

1110
@dataclass
1211
class GlobalRecordState:
13-
stack: list[RecordIdx] = field(default_factory=list)
12+
buffer: list[RecordIdx] = field(default_factory=list)
1413

1514
# assume that this RecordIdx will always be -1
16-
def increment_record_idx(self) -> RecordIdx:
15+
def add_record_idxs(self, num_new_records: int) -> list[RecordIdx]:
1716
# adjust all previous indices
18-
for record_idx in self.stack:
19-
record_idx.idx -= 1
20-
self.stack.append(RecordIdx(-1))
21-
# Return for usage
22-
return self.stack[-1]
17+
for record_idx in self.buffer:
18+
record_idx.idx -= num_new_records
2319

24-
def drop_record_idx(self, record_to_drop: RecordIdx):
25-
# there is a chance now that the ordering is messed up but
26-
# we can now update the indices to enforce consistency.
27-
# We only have to update UP to the entry that was just removed
28-
# everything else maintains ordering
29-
dropped_idx = record_to_drop.idx
30-
self.stack.remove(record_to_drop)
31-
for record_idx in self.stack:
32-
if record_idx.idx < dropped_idx:
33-
record_idx.idx += 1
20+
# generate new indices and add them to the buffer
21+
new_record_idxs = [RecordIdx(-i) for i in range(num_new_records, 0, -1)]
22+
self.buffer += new_record_idxs
23+
# Return for usage, idxs linked to the global state
24+
return new_record_idxs
25+
26+
"""
27+
Might need a free after use! You can keep the size of the list small
28+
but could be a premature optimization...
29+
"""
30+
# def drop_record_idxs(self, record_tuple: RecordTuple):
31+
# for record_idx in record_tuple.members:
32+
# self.buffer.remove(record_idx)
3433

3534

3635
@dataclass
@@ -47,7 +46,7 @@ def initialize_frame(
4746
) -> RecordFrame:
4847
return RecordFrame(node, has_parent_access=has_parent_access)
4948

50-
def eval_stmt_fallback(
49+
def eval_fallback(
5150
self, frame: RecordFrame, node: ir.Statement
5251
) -> tuple[Record, ...]:
5352
return tuple(self.lattice.bottom() for _ in node.results)
@@ -56,17 +55,5 @@ def run_method(self, method, args: tuple[Record, ...]):
5655
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
5756
return self.run_method(method.code, (self.lattice.bottom(),) + args)
5857

59-
T = TypeVar("T")
60-
61-
def get_const_value(
62-
self, input_type: type[T], value: ir.SSAValue
63-
) -> type[T] | None:
64-
if isinstance(hint := value.hints.get("const"), const.Value):
65-
data = hint.data
66-
if isinstance(data, input_type):
67-
return hint.data
68-
69-
return None
70-
7158
def method_self(self, method: ir.Method) -> Record:
7259
return self.lattice.bottom()

src/bloqade/analysis/record/impls.py

Lines changed: 68 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from copy import deepcopy
22

33
from kirin import types as kirin_types, interp
4+
from kirin.ir import PyAttr
45
from kirin.dialects import py, scf, ilist
56

67
from bloqade import qubit, annotate
@@ -9,9 +10,9 @@
910
from .lattice import (
1011
AnyRecord,
1112
NotRecord,
12-
RecordIdx,
1313
RecordTuple,
1414
InvalidRecord,
15+
ConstantCarrier,
1516
ImmutableRecords,
1617
)
1718
from .analysis import RecordFrame, RecordAnalysis
@@ -57,10 +58,7 @@ def measure_qubit_list(
5758
if not isinstance(num_qubits, kirin_types.Literal):
5859
return (AnyRecord(),)
5960

60-
record_idxs = []
61-
for _ in range(num_qubits.data):
62-
record_idx = frame.global_record_state.increment_record_idx()
63-
record_idxs.append(record_idx)
61+
record_idxs = frame.global_record_state.add_record_idxs(num_qubits.data)
6462

6563
return (RecordTuple(members=tuple(record_idxs)),)
6664

@@ -70,9 +68,22 @@ class PyIndexing(interp.MethodTable):
7068
@interp.impl(py.GetItem)
7169
def getitem(self, interp: RecordAnalysis, frame: RecordFrame, stmt: py.GetItem):
7270

73-
idx_or_slice = interp.get_const_value((int, slice), stmt.index)
74-
if idx_or_slice is None:
75-
return (InvalidRecord(),)
71+
# maybe_const will work fine outside of any loops because
72+
# constprop will put the expected data into a hint.
73+
74+
# if maybeconst fails, we fall back to getting the value from the frame
75+
# (note that even outside loops, the constant impl will happily
76+
# capture integer/slice constants so if THAT fails, then something
77+
# has truly gone wrong).
78+
possible_idx_or_slice = interp.maybe_const(stmt.index, (int, slice))
79+
if possible_idx_or_slice is not None:
80+
idx_or_slice = possible_idx_or_slice
81+
else:
82+
idx_or_slice = frame.get(stmt.index)
83+
if not isinstance(idx_or_slice, ConstantCarrier):
84+
return (InvalidRecord(),)
85+
else:
86+
idx_or_slice = idx_or_slice.value
7687

7788
obj = frame.get(stmt.obj)
7889
if isinstance(obj, RecordTuple):
@@ -106,36 +117,68 @@ class PyAlias(interp.MethodTable):
106117
@interp.impl(py.Alias)
107118
def alias(
108119
self,
109-
interp: RecordAnalysis,
120+
interp_: RecordAnalysis,
110121
frame: RecordFrame,
111122
stmt: py.Alias,
112123
):
113-
value = frame.get(stmt.value)
114-
if isinstance(value, RecordIdx):
115-
frame.global_record_state.drop_record_idx(value)
116-
elif isinstance(value, RecordTuple):
117-
for member in value.members:
118-
frame.global_record_state.drop_record_idx(member)
124+
input = frame.get(stmt.value) # expect this to be a RecordTuple
119125

120-
return (value,)
126+
# two variables share the same references in the global state
127+
return (input,)
121128

122129

123130
@scf.dialect.register(key="record")
124-
class LoopHandling(scf.absint.Methods):
131+
class LoopHandling(interp.MethodTable):
125132
@interp.impl(scf.stmts.For)
126133
def for_loop(
127134
self, interp_: RecordAnalysis, frame: RecordFrame, stmt: scf.stmts.For
128135
):
129136

130-
# this will contain the in-loop measure variable declared outside the loop
131137
loop_vars = frame.get_values(stmt.initializers)
132-
# NotRecord in the beginning just lets the sink have some value
133-
loop_vars = interp_.run_ssacfg_region(frame, stmt.body, loop_vars)
134138

135-
# need to update the information in the frame
136-
if isinstance(loop_vars, interp.ReturnValue):
137-
return loop_vars
138-
elif loop_vars is None:
139-
loop_vars = ()
139+
for _ in range(2):
140+
loop_vars = interp_.frame_call_region(
141+
frame, stmt, stmt.body, InvalidRecord(), *loop_vars
142+
)
143+
144+
if loop_vars is None:
145+
loop_vars = ()
146+
147+
elif isinstance(loop_vars, interp.ReturnValue):
148+
return loop_vars
140149

141150
return loop_vars
151+
152+
@interp.impl(scf.stmts.Yield)
153+
def for_yield(
154+
self, interp_: RecordAnalysis, frame: RecordFrame, stmt: scf.stmts.Yield
155+
):
156+
return interp.YieldValue(frame.get_values(stmt.values))
157+
158+
159+
# Only carry about carrying integers for now because
160+
# the current issue is that
161+
@py.dialect.register(key="record")
162+
class ConstantForwarding(interp.MethodTable):
163+
@interp.impl(py.Constant)
164+
def constant(
165+
self,
166+
interp_: RecordAnalysis,
167+
frame: RecordFrame,
168+
stmt: py.Constant,
169+
):
170+
# can't use interp_.maybe_const/expect_const because it assumes the data is already
171+
# there to begin with...
172+
if not isinstance(stmt.value, PyAttr):
173+
return (InvalidRecord(),)
174+
175+
expected_int_or_slice = stmt.value.data
176+
177+
if not isinstance(expected_int_or_slice, (int, slice)):
178+
return (InvalidRecord(),)
179+
180+
return (ConstantCarrier(value=expected_int_or_slice),)
181+
182+
183+
# outside_frame -> create new frame with context manager COPIED from outside frame
184+
# the frame and the stack are separate

src/bloqade/analysis/record/lattice.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,19 @@ def is_subseteq(self, other: Record) -> bool:
5555
return isinstance(other, NotRecord)
5656

5757

58+
# For now I only care about propagating constant integers or slices,
59+
# things that can be used as indices to list of measurements
60+
@final
61+
@dataclass
62+
class ConstantCarrier(Record):
63+
value: int | slice
64+
65+
def is_subseteq(self, other: Record) -> bool:
66+
if isinstance(other, ConstantCarrier):
67+
return self.value == other.value
68+
return False
69+
70+
5871
@final
5972
@dataclass
6073
class RecordIdx(Record):

src/bloqade/stim/dialects/cf/__init__.py

Whitespace-only changes.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from kirin import ir
2+
3+
dialect = ir.Dialect("stim.cf")

src/bloqade/stim/dialects/cf/stmts.py

Whitespace-only changes.

src/bloqade/stim/passes/simplify_ifs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class StimSimplifyIfs(Pass):
2222
def unsafe_run(self, mt: ir.Method):
2323

2424
result = Chain(
25-
Walk(UnusedYield()),
25+
Walk(UnusedYield()), # this is being too aggressive, need to file an issue
2626
Walk(StimLiftThenBody()),
2727
# remove yields (if possible), then lift out as much stuff as possible
2828
Walk(DeadCodeElimination()),

src/bloqade/stim/passes/soft_flatten.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,12 @@ class SoftFlatten(Pass):
8181

8282
def __post_init__(self):
8383
self.unroll = AggressiveUnroll(self.dialects, no_raise=self.no_raise)
84-
self.simplify_if = StimSimplifyIfs(self.dialects, no_raise=self.no_raise)
84+
85+
# DO NOT USE FOR NOW, TrimUnusedYield call messes up loop structure
86+
# self.simplify_if = StimSimplifyIfs(self.dialects, no_raise=self.no_raise)
8587

8688
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
8789
rewrite_result = RewriteResult()
88-
rewrite_result = self.simplify_if(mt).join(rewrite_result)
90+
# rewrite_result = self.simplify_if(mt).join(rewrite_result)
8991
rewrite_result = self.unroll(mt).join(rewrite_result)
9092
return rewrite_result

0 commit comments

Comments
 (0)