Skip to content

Commit b4c4559

Browse files
committed
still need to find a solution to proper IfElse handling
1 parent 8c5f453 commit b4c4559

File tree

6 files changed

+71
-38
lines changed

6 files changed

+71
-38
lines changed

src/bloqade/analysis/measure_id/impls.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
MeasureIdTuple,
1515
ConstantCarrier,
1616
InvalidMeasureId,
17-
ImmutableMeasureIds,
1817
PredicatedMeasureId,
1918
)
2019
from .analysis import MeasureIDFrame, MeasurementIDAnalysis
@@ -80,7 +79,7 @@ def measurement_predicate(
8079
PredicatedMeasureId(measure_id.idx, predicate)
8180
for measure_id in original_measure_id_tuple.data
8281
]
83-
return (ImmutableMeasureIds(data=tuple(predicate_measure_ids)),)
82+
return (MeasureIdTuple(data=tuple(predicate_measure_ids)),)
8483

8584

8685
@gemini.logical.dialect.register(key="measure_id")
@@ -107,7 +106,7 @@ def terminal_measurement(
107106
# Immutable usually desired for stim generation
108107
# but we can reuse it here to indicate
109108
# the measurement ids should not change anymore.
110-
return (ImmutableMeasureIds(data=tuple(measure_id_bools)),)
109+
return (MeasureIdTuple(data=tuple(measure_id_bools), immutable=True),)
111110

112111

113112
@annotate.dialect.register(key="measure_id")
@@ -124,17 +123,15 @@ def consumes_measurements(
124123

125124
if not (
126125
isinstance(measure_id_tuple_at_stmt, MeasureIdTuple)
127-
and kirin_types.is_tuple_of(
128-
measure_id_tuple_at_stmt.data, PredicatedMeasureId
129-
)
126+
and kirin_types.is_tuple_of(measure_id_tuple_at_stmt.data, RawMeasureId)
130127
):
131128
return (InvalidMeasureId(),)
132129

133130
final_record_idxs = [
134131
deepcopy(record_idx) for record_idx in measure_id_tuple_at_stmt.data
135132
]
136133

137-
return (ImmutableMeasureIds(data=tuple(final_record_idxs)),)
134+
return (MeasureIdTuple(data=tuple(final_record_idxs), immutable=True),)
138135

139136

140137
@ilist.dialect.register(key="measure_id")
@@ -150,7 +147,7 @@ def new_ilist(
150147
stmt: ilist.New,
151148
):
152149

153-
return (MeasureIdTuple(frame.get_values(stmt.values)),)
150+
return (MeasureIdTuple(data=frame.get_values(stmt.values)),)
154151

155152

156153
@py.tuple.dialect.register(key="measure_id")
@@ -345,6 +342,21 @@ def for_yield(
345342
):
346343
return interp.YieldValue(frame.get_values(stmt.values))
347344

345+
@interp.impl(scf.stmts.IfElse)
346+
def if_else(
347+
self,
348+
interp_: MeasurementIDAnalysis,
349+
frame: MeasureIDFrame,
350+
stmt: scf.stmts.IfElse,
351+
):
352+
cond_measure_id = frame.get(stmt.cond)
353+
assert type(cond_measure_id) is PredicatedMeasureId
354+
detached_cond_measure_id = PredicatedMeasureId(
355+
idx=deepcopy(cond_measure_id.idx), predicate=cond_measure_id.predicate
356+
)
357+
# remove underlying reference to the frame
358+
frame.set(stmt.cond, detached_cond_measure_id)
359+
348360

349361
@py.dialect.register(key="measure_id")
350362
class ConstantForwarding(interp.MethodTable):

src/bloqade/analysis/measure_id/lattice.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -93,27 +93,14 @@ def is_subseteq(self, other: MeasureId) -> bool:
9393
@dataclass
9494
class MeasureIdTuple(MeasureId):
9595
data: tuple[MeasureId, ...]
96+
immutable: bool = False
9697

9798
def is_subseteq(self, other: MeasureId) -> bool:
9899
if isinstance(other, MeasureIdTuple):
99100
return all(a.is_subseteq(b) for a, b in zip(self.data, other.data))
100101
return False
101102

102103

103-
@final
104-
@dataclass
105-
class ImmutableMeasureIds(MeasureId):
106-
# SetDetector happily consumes RawMeasureIds, but
107-
# for scf.IfElse rewrite with predicates I need to allow
108-
# PredicatedMeasureIds as well.
109-
data: tuple[PredicatedMeasureId | RawMeasureId, ...]
110-
111-
def is_subseteq(self, other: MeasureId) -> bool:
112-
if isinstance(other, ImmutableMeasureIds):
113-
return all(a.is_subseteq(b) for a, b in zip(self.data, other.data))
114-
return False
115-
116-
117104
# For now I only care about propagating constant integers or slices,
118105
# things that can be used as indices to list of measurements
119106
@final

src/bloqade/stim/passes/soft_flatten.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,10 @@ class SoftFlatten(Pass):
8181

8282
def __post_init__(self):
8383
self.unroll = AggressiveUnroll(self.dialects, no_raise=self.no_raise)
84-
85-
# DO NOT USE FOR NOW, TrimUnusedYield call messes up loop structure
86-
# self.simplify_if = StimSimplifyIfs(self.dialects, no_raise=self.no_raise)
84+
self.simplify_if = StimSimplifyIfs(self.dialects, no_raise=self.no_raise)
8785

8886
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
8987
rewrite_result = RewriteResult()
90-
# rewrite_result = self.simplify_if(mt).join(rewrite_result)
88+
rewrite_result = self.simplify_if(mt).join(rewrite_result)
9189
rewrite_result = self.unroll(mt).join(rewrite_result)
9290
return rewrite_result

src/bloqade/stim/rewrite/get_record_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.dialects import py
33

44
from bloqade.stim.dialects import auxiliary
5-
from bloqade.analysis.measure_id.lattice import MeasureIdTuple, PredicatedMeasureId
5+
from bloqade.analysis.measure_id.lattice import RawMeasureId, MeasureIdTuple
66

77

88
def insert_get_records(node: ir.Statement, measure_id_tuple: MeasureIdTuple):
@@ -11,7 +11,7 @@ def insert_get_records(node: ir.Statement, measure_id_tuple: MeasureIdTuple):
1111
"""
1212
get_record_ssas = []
1313
for known_measure_id in measure_id_tuple.data:
14-
assert isinstance(known_measure_id, PredicatedMeasureId)
14+
assert isinstance(known_measure_id, RawMeasureId)
1515
idx_stmt = py.constant.Constant(known_measure_id.idx)
1616
idx_stmt.insert_before(node)
1717
get_record_stmt = auxiliary.GetRecord(idx_stmt.result)

test/analysis/measure_id/test_measure_id.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
from bloqade import squin, gemini
55
from bloqade.analysis.measure_id import MeasurementIDAnalysis
66
from bloqade.stim.passes.flatten import Flatten
7+
8+
# from bloqade.stim.passes.soft_flatten import SoftFlatten
9+
from bloqade.stim.passes.squin_to_stim import SquinToStimPass
710
from bloqade.analysis.measure_id.lattice import (
811
Predicate,
912
NotMeasureId,
1013
RawMeasureId,
1114
MeasureIdTuple,
1215
InvalidMeasureId,
13-
ImmutableMeasureIds,
1416
PredicatedMeasureId,
1517
)
1618

@@ -302,29 +304,29 @@ def test():
302304
test, ("is_zero_bools", "is_one_bools", "is_lost_bools")
303305
)
304306

305-
expected_is_zero_bools = ImmutableMeasureIds(
307+
expected_is_zero_bools = MeasureIdTuple(
306308
data=tuple(
307309
[
308310
PredicatedMeasureId(idx=i, predicate=Predicate.IS_ZERO)
309311
for i in range(-3, 0)
310312
]
311-
)
313+
),
312314
)
313-
expected_is_one_bools = ImmutableMeasureIds(
315+
expected_is_one_bools = MeasureIdTuple(
314316
data=tuple(
315317
[
316318
PredicatedMeasureId(idx=i, predicate=Predicate.IS_ONE)
317319
for i in range(-3, 0)
318320
]
319-
)
321+
),
320322
)
321-
expected_is_lost_bools = ImmutableMeasureIds(
323+
expected_is_lost_bools = MeasureIdTuple(
322324
data=tuple(
323325
[
324326
PredicatedMeasureId(idx=i, predicate=Predicate.IS_LOST)
325327
for i in range(-3, 0)
326328
]
327-
)
329+
),
328330
)
329331

330332
assert frame.get(results["is_zero_bools"]) == expected_is_zero_bools
@@ -347,9 +349,38 @@ def tm_logical_kernel():
347349
# basically a container of InvalidMeasureIds from the qubits that get allocated
348350
tm_logical_kernel.print(analysis=frame.entries)
349351
analysis_results = [
350-
val for val in frame.entries.values() if isinstance(val, ImmutableMeasureIds)
352+
val for val in frame.entries.values() if isinstance(val, MeasureIdTuple)
351353
]
352-
expected_result = ImmutableMeasureIds(
353-
data=tuple([RawMeasureId(idx=-i) for i in range(1, 4)])
354+
expected_result = MeasureIdTuple(
355+
data=tuple([RawMeasureId(idx=-i) for i in range(1, 4)]),
356+
immutable=True,
354357
)
355358
assert expected_result in analysis_results
359+
360+
361+
def test_if_else_happy_path():
362+
363+
@squin.kernel
364+
def test():
365+
qs = squin.qalloc(3)
366+
ms = squin.broadcast.measure(qs)
367+
# predicate
368+
pred_ms = squin.broadcast.is_one(ms)
369+
squin.broadcast.measure(qs)
370+
squin.broadcast.measure(qs)
371+
if pred_ms[0]:
372+
squin.x(qs[1])
373+
374+
return
375+
376+
# Flatten(test.dialects).fixpoint(test)
377+
# SoftFlatten(test.dialects).fixpoint(test)
378+
test.print()
379+
SquinToStimPass(test.dialects)(test)
380+
test.print()
381+
# test.print()
382+
# frame, _ = MeasurementIDAnalysis(test.dialects).run(test)
383+
# test.print(analysis=frame.entries)
384+
385+
386+
test_if_else_happy_path()

test/stim/passes/test_annotation_to_stim.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,15 @@ def main():
148148

149149
return
150150

151+
main.print()
151152
SquinToStimPass(main.dialects, no_raise=True)(main)
153+
main.print()
152154
assert any(isinstance(stmt, scf.IfElse) for stmt in main.code.regions[0].stmts())
153155

154156

157+
test_missing_predicate()
158+
159+
155160
def test_incorrect_predicate():
156161

157162
# You can only rewrite squin.is_one(...) predicates to

0 commit comments

Comments
 (0)