Skip to content

Commit dfdd6ea

Browse files
johnzl-777kaihsin
andcommitted
add alias support to measure_id analysis (#412)
@ehua7365 encountered a problem where reassigning an IList of measure results would cause `SquinToStim` to fail: ``` q = qubit.new(4) ms = qubit.measure(q) new_ms = ms if new_ms[0]: qubit.apply(op.z(), q[0]) ``` (This is an intentionally contrived example to keep things short). Originally I had directed Eric to include kirin's InlineAlias and InlineGetItem rules which also got rid of the issue until I took a closer look and realized the problem is because when an alias happens with an IList of measure results, the analysis does not propagate the corresponding data. As a result, further Stim-related rewrites have problems because the MeasureIDAnalysis data is not available for the subsequent SSA Values. --------- Co-authored-by: Kai-Hsin Wu <[email protected]>
1 parent 65ed26a commit dfdd6ea

File tree

5 files changed

+66
-0
lines changed

5 files changed

+66
-0
lines changed

src/bloqade/analysis/measure_id/impls.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,15 @@ def getitem(
114114
return (InvalidMeasureId(),)
115115

116116

117+
@py.assign.dialect.register(key="measure_id")
118+
class PyAssign(interp.MethodTable):
119+
@interp.impl(py.Alias)
120+
def alias(
121+
self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.assign.Alias
122+
):
123+
return (frame.get(stmt.value),)
124+
125+
117126
@py.binop.dialect.register(key="measure_id")
118127
class PyBinOp(interp.MethodTable):
119128
@interp.impl(py.Add)

src/bloqade/stim/passes/squin_to_stim.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from kirin.passes.abc import Pass
1717
from kirin.rewrite.abc import RewriteResult
1818
from kirin.passes.inline import InlinePass
19+
from kirin.rewrite.alias import InlineAlias
1920

2021
from bloqade.stim.rewrite import (
2122
SquinWireToStim,
@@ -89,6 +90,9 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
8990
rewrite_result = (
9091
Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(rewrite_result)
9192
)
93+
94+
Walk(InlineAlias()).rewrite(mt.code).join(rewrite_result)
95+
9296
rewrite_result = (
9397
StimSimplifyIfs(mt.dialects, no_raise=self.no_raise)
9498
.unsafe_run(mt)

test/analysis/measure_id/test_measure_id.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,37 @@ def test():
3939
assert measure_id_tuples[-1] == expected_measure_id_tuple
4040

4141

42+
def test_measure_alias():
43+
44+
@kernel
45+
def test():
46+
ql = qubit.new(5)
47+
ml = qubit.measure(ql)
48+
ml_alias = ml
49+
50+
return ml_alias
51+
52+
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
53+
54+
test.print(analysis=frame.entries)
55+
56+
# Collect MeasureIdTuples
57+
measure_id_tuples = [
58+
value for value in frame.entries.values() if isinstance(value, MeasureIdTuple)
59+
]
60+
61+
# construct expected MeasureIdTuple
62+
expected_measure_id_tuple = MeasureIdTuple(
63+
data=tuple([MeasureIdBool(idx=i) for i in range(1, 6)])
64+
)
65+
66+
assert len(measure_id_tuples) == 2
67+
assert all(
68+
measure_id_tuple == expected_measure_id_tuple
69+
for measure_id_tuple in measure_id_tuples
70+
)
71+
72+
4273
def test_measure_count_at_if_else():
4374

4475
@kernel
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
2+
MZ(0.00000000) 0 1 2 3
3+
CZ rec[-4] 0

test/stim/passes/test_squin_meas_to_stim.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,25 @@ def main():
5353
assert base_stim_prog == codegen(main)
5454

5555

56+
def test_alias_with_measure_list():
57+
58+
@squin.kernel
59+
def main():
60+
61+
q = qubit.new(4)
62+
ms = qubit.measure(q)
63+
new_ms = ms
64+
65+
if new_ms[0]:
66+
qubit.apply(op.z(), q[0])
67+
68+
SquinToStimPass(main.dialects)(main)
69+
70+
base_stim_prog = load_reference_program("alias_with_measure_list.stim")
71+
72+
assert base_stim_prog == codegen(main)
73+
74+
5675
def test_record_index_order():
5776

5877
@squin.kernel

0 commit comments

Comments
 (0)