Skip to content

Commit a39a126

Browse files
committed
latest attempt to try to reconcile type lattices
1 parent 79880d4 commit a39a126

File tree

9 files changed

+96
-111
lines changed

9 files changed

+96
-111
lines changed

src/bloqade/analysis/measure_id/analysis.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,17 @@
44
from kirin.analysis import ForwardExtra
55
from kirin.analysis.forward import ForwardFrame
66

7-
from .lattice import MeasureId, NotMeasureId, KnownMeasureId, MeasureIdTuple
7+
from .lattice import (
8+
MeasureId,
9+
NotMeasureId,
10+
RawMeasureId,
11+
MeasureIdTuple,
12+
)
813

914

1015
@dataclass
1116
class GlobalRecordState:
12-
buffer: list[KnownMeasureId] = field(default_factory=list)
17+
buffer: list[RawMeasureId] = field(default_factory=list)
1318

1419
# assume that this KnownMeasureId will always be -1
1520
def add_record_idxs(self, num_new_records: int) -> MeasureIdTuple:
@@ -18,7 +23,7 @@ def add_record_idxs(self, num_new_records: int) -> MeasureIdTuple:
1823
record_idx.idx -= num_new_records
1924

2025
# generate new indices and add them to the buffer
21-
new_record_idxs = [KnownMeasureId(-i) for i in range(num_new_records, 0, -1)]
26+
new_record_idxs = [RawMeasureId(-i) for i in range(num_new_records, 0, -1)]
2227
self.buffer += new_record_idxs
2328
# Return for usage, idxs linked to the global state
2429
return MeasureIdTuple(data=tuple(new_record_idxs))
@@ -30,13 +35,13 @@ def add_record_idxs(self, num_new_records: int) -> MeasureIdTuple:
3035
# buffer
3136
def clone_record_idxs(self, measure_id_tuple: MeasureIdTuple) -> MeasureIdTuple:
3237
cloned_members = []
33-
for known_measure_id in measure_id_tuple.data:
34-
assert isinstance(known_measure_id, KnownMeasureId)
35-
cloned_known_measure_id = KnownMeasureId(known_measure_id.idx)
38+
for raw_measure_id in measure_id_tuple.data:
39+
assert isinstance(raw_measure_id, RawMeasureId)
40+
cloned_raw_measure_id = RawMeasureId(raw_measure_id.idx)
3641
# put into the global buffer but also
3742
# return an analysis-facing copy
38-
self.buffer.append(cloned_known_measure_id)
39-
cloned_members.append(cloned_known_measure_id)
43+
self.buffer.append(cloned_raw_measure_id)
44+
cloned_members.append(cloned_raw_measure_id)
4045
return MeasureIdTuple(data=tuple(cloned_members))
4146

4247
def offset_existing_records(self, offset: int):

src/bloqade/analysis/measure_id/impls.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010
Predicate,
1111
AnyMeasureId,
1212
NotMeasureId,
13-
KnownMeasureId,
13+
RawMeasureId,
1414
MeasureIdTuple,
1515
ConstantCarrier,
1616
InvalidMeasureId,
1717
ImmutableMeasureIds,
18+
PredicatedMeasureId,
1819
)
1920
from .analysis import MeasureIDFrame, MeasurementIDAnalysis
2021

@@ -61,7 +62,7 @@ def measurement_predicate(
6162
):
6263
original_measure_id_tuple = frame.get(stmt.measurements)
6364
if not all(
64-
isinstance(measure_id, KnownMeasureId)
65+
isinstance(measure_id, RawMeasureId)
6566
for measure_id in original_measure_id_tuple.data
6667
):
6768
return (InvalidMeasureId(),)
@@ -76,10 +77,10 @@ def measurement_predicate(
7677
return (InvalidMeasureId(),)
7778

7879
predicate_measure_ids = [
79-
KnownMeasureId(measure_id.idx, predicate)
80+
PredicatedMeasureId(measure_id.idx, predicate)
8081
for measure_id in original_measure_id_tuple.data
8182
]
82-
return (MeasureIdTuple(data=tuple(predicate_measure_ids)),)
83+
return (ImmutableMeasureIds(data=tuple(predicate_measure_ids)),)
8384

8485

8586
@gemini.logical.dialect.register(key="measure_id")
@@ -100,11 +101,13 @@ def terminal_measurement(
100101
return (AnyMeasureId(),)
101102

102103
measure_id_bools = []
103-
for _ in range(num_qubits.data):
104-
interp.measure_count += 1
105-
measure_id_bools.append(RawMeasureId(interp.measure_count))
104+
for i in range(num_qubits.data):
105+
measure_id_bools.append(RawMeasureId(idx=-(i + 1)))
106106

107-
return (MeasureIdTuple(data=tuple(measure_id_bools)),)
107+
# Immutable usually desired for stim generation
108+
# but we can reuse it here to indicate
109+
# the measurement ids should not change anymore.
110+
return (ImmutableMeasureIds(data=tuple(measure_id_bools)),)
108111

109112

110113
@annotate.dialect.register(key="measure_id")
@@ -121,7 +124,9 @@ def consumes_measurements(
121124

122125
if not (
123126
isinstance(measure_id_tuple_at_stmt, MeasureIdTuple)
124-
and kirin_types.is_tuple_of(measure_id_tuple_at_stmt.data, KnownMeasureId)
127+
and kirin_types.is_tuple_of(
128+
measure_id_tuple_at_stmt.data, PredicatedMeasureId
129+
)
125130
):
126131
return (InvalidMeasureId(),)
127132

@@ -241,7 +246,7 @@ def invoke(
241246

242247

243248
@scf.dialect.register(key="measure_id")
244-
class LoopHandling(interp.MethodTable):
249+
class ScfHandling(interp.MethodTable):
245250
@interp.impl(scf.stmts.For)
246251
def for_loop(
247252
self, interp_: MeasurementIDAnalysis, frame: MeasureIDFrame, stmt: scf.stmts.For
@@ -323,7 +328,7 @@ def for_loop(
323328
if isinstance(var, MeasureIdTuple):
324329
for member in var.data:
325330
if (
326-
isinstance(member, KnownMeasureId)
331+
isinstance(member, RawMeasureId)
327332
and member.idx not in witnessed_record_idxs
328333
):
329334
witnessed_record_idxs.add(member.idx)

src/bloqade/analysis/measure_id/lattice.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,23 @@ def is_subseteq(self, other: MeasureId) -> bool:
6868

6969
@final
7070
@dataclass
71-
class KnownMeasureId(MeasureId):
71+
class RawMeasureId(MeasureId):
72+
idx: int
73+
74+
def is_subseteq(self, other: MeasureId) -> bool:
75+
if isinstance(other, RawMeasureId):
76+
return self.idx == other.idx
77+
return False
78+
79+
80+
@final
81+
@dataclass
82+
class PredicatedMeasureId(MeasureId):
7283
idx: int
7384
predicate: Predicate
7485

7586
def is_subseteq(self, other: MeasureId) -> bool:
76-
if isinstance(other, KnownMeasureId):
87+
if isinstance(other, PredicatedMeasureId):
7788
return self.idx == other.idx and self.predicate == other.predicate
7889
return False
7990

@@ -92,7 +103,10 @@ def is_subseteq(self, other: MeasureId) -> bool:
92103
@final
93104
@dataclass
94105
class ImmutableMeasureIds(MeasureId):
95-
data: tuple[KnownMeasureId, ...]
106+
# SetDetector happily consumes RawMeasureIds, but
107+
# for scf.IfElse rewrite with predicates I need to allow
108+
# PredicatedMeasureIds as well.
109+
data: tuple[PredicatedMeasureId | RawMeasureId, ...]
96110

97111
def is_subseteq(self, other: MeasureId) -> bool:
98112
if isinstance(other, ImmutableMeasureIds):

src/bloqade/stim/passes/squin_to_stim.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
from ..rewrite import IfToStim, SetDetectorToStim, SetObservableToStim
3131

32+
# from bloqade.stim.passes.soft_flatten import SoftFlatten
33+
3234

3335
@dataclass
3436
class SquinToStimPass(Pass):

src/bloqade/stim/rewrite/get_record_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.dialects import py
33

44
from bloqade.stim.dialects import auxiliary
5-
from bloqade.analysis.measure_id.lattice import KnownMeasureId, MeasureIdTuple
5+
from bloqade.analysis.measure_id.lattice import MeasureIdTuple, PredicatedMeasureId
66

77

88
def insert_get_records(node: ir.Statement, measure_id_tuple: MeasureIdTuple):
@@ -11,7 +11,7 @@ def insert_get_records(node: ir.Statement, measure_id_tuple: MeasureIdTuple):
1111
"""
1212
get_record_ssas = []
1313
for known_measure_id in measure_id_tuple.data:
14-
assert isinstance(known_measure_id, KnownMeasureId)
14+
assert isinstance(known_measure_id, PredicatedMeasureId)
1515
idx_stmt = py.constant.Constant(known_measure_id.idx)
1616
idx_stmt.insert_before(node)
1717
get_record_stmt = auxiliary.GetRecord(idx_stmt.result)

src/bloqade/stim/rewrite/ifs_to_stim.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from bloqade.stim.dialects.gate import CX as stim_CX, CY as stim_CY, CZ as stim_CZ
1414
from bloqade.analysis.measure_id import MeasureIDFrame
1515
from bloqade.stim.dialects.auxiliary import GetRecord
16-
from bloqade.analysis.measure_id.lattice import Predicate, KnownMeasureId
16+
from bloqade.analysis.measure_id.lattice import Predicate, PredicatedMeasureId
1717

1818

1919
@dataclass
@@ -140,7 +140,7 @@ def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:
140140
condition_type = self.measure_frame.entries.get(stmt.cond)
141141
# Check the condition is a singular MeasurementIdBool and that
142142
# it was generated by querying if the measurement is equivalent to the one state
143-
if not isinstance(condition_type, KnownMeasureId):
143+
if not isinstance(condition_type, PredicatedMeasureId):
144144
return RewriteResult()
145145

146146
if condition_type.predicate != Predicate.IS_ONE:
@@ -162,12 +162,8 @@ def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:
162162
return RewriteResult()
163163

164164
# generate get record statement
165-
num_measures = self.measure_frame.num_measures_at_stmt.get(stmt)
166-
if num_measures is None:
167-
return RewriteResult()
168-
169-
measure_id_idx_stmt = py.Constant((condition_type.idx - 1) - num_measures)
170-
get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841
165+
measure_id_idx_stmt = py.Constant(condition_type.idx)
166+
get_record_stmt = GetRecord(id=measure_id_idx_stmt.result)
171167

172168
address_attr = stmts[0].qubits.hints.get("address")
173169

0 commit comments

Comments
 (0)