Skip to content

Commit e0b99db

Browse files
johnzl-777david-pl
andcommitted
fix nested list handling in SquinToStimPass (#405)
Fixes #404 and also caught a mistake I introduced which is that AggressiveUnroll does not `join` the rewrite_result chain. --------- Co-authored-by: David Plankensteiner <[email protected]>
1 parent 20576e8 commit e0b99db

File tree

4 files changed

+50
-7
lines changed

4 files changed

+50
-7
lines changed

src/bloqade/stim/passes/simplify_ifs.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
ConstantFold,
1010
CommonSubexpressionElimination,
1111
)
12+
from kirin.dialects.ilist.passes import ConstList2IList
1213

1314
from ..rewrite.ifs_to_stim import StimLiftThenBody, StimSplitIfStmts
1415

@@ -23,8 +24,16 @@ def unsafe_run(self, mt: ir.Method):
2324
Walk(StimSplitIfStmts()),
2425
).rewrite(mt.code)
2526

27+
# because nested python lists don't have their
28+
# member lists converted to ILists, ConstantFold
29+
# can add python lists that can't be hashed, causing
30+
# issues with CSE. ConstList2IList remedies that problem here.
2631
result = (
27-
Fixpoint(Walk(Chain(ConstantFold(), CommonSubexpressionElimination())))
32+
Chain(
33+
Fixpoint(Walk(ConstantFold())),
34+
Walk(ConstList2IList()),
35+
Walk(CommonSubexpressionElimination()),
36+
)
2837
.rewrite(mt.code)
2938
.join(result)
3039
)

src/bloqade/stim/passes/squin_to_stim.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass
22

3-
from kirin.passes import Fold, HintConst
3+
from kirin.passes import Fold, HintConst, TypeInfer
44
from kirin.rewrite import (
55
Walk,
66
Chain,
@@ -33,6 +33,7 @@
3333
from bloqade.rewrite.passes import CanonicalizeIList
3434
from bloqade.analysis.address import AddressAnalysis
3535
from bloqade.analysis.measure_id import MeasurementIDAnalysis
36+
from bloqade.squin.rewrite.desugar import ApplyDesugarRule
3637

3738
from .simplify_ifs import StimSimplifyIfs
3839
from ..rewrite.ifs_to_stim import IfToStim
@@ -79,9 +80,11 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
7980
dialects=mt.dialects, no_raise=self.no_raise
8081
).unsafe_run(mt)
8182

82-
rewrite_result = AggressiveForLoopUnroll(
83-
dialects=mt.dialects, no_raise=self.no_raise
84-
).fixpoint(mt)
83+
rewrite_result = (
84+
AggressiveForLoopUnroll(dialects=mt.dialects, no_raise=self.no_raise)
85+
.fixpoint(mt)
86+
.join(rewrite_result)
87+
)
8588

8689
rewrite_result = (
8790
Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(rewrite_result)
@@ -105,6 +108,9 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
105108
.join(rewrite_result)
106109
)
107110

111+
TypeInfer(dialects=mt.dialects, no_raise=self.no_raise).unsafe_run(mt)
112+
Walk(ApplyDesugarRule()).rewrite(mt.code)
113+
108114
# after this the program should be in a state where it is analyzable
109115
# -------------------------------------------------------------------
110116

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
2+
H 0
3+
H 2

test/stim/passes/test_squin_qubit_to_stim.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def main():
161161
squin.qubit.apply(squin.op.h(), q[0])
162162
cx = squin.op.cx()
163163
for i in range(2):
164-
squin.qubit.apply(cx, [q[i], q[i + 1]])
164+
squin.qubit.apply(cx, q[i], q[i + 1])
165165

166166
SquinToStimPass(main.dialects)(main)
167167
base_stim_prog = load_reference_program("for_loop_nontrivial_index.stim")
@@ -178,9 +178,34 @@ def main():
178178
cx = squin.op.cx()
179179
for i in range(2):
180180
for j in range(2, 3):
181-
squin.qubit.apply(cx, [q[i], q[j]])
181+
squin.qubit.apply(cx, q[i], q[j])
182182

183183
SquinToStimPass(main.dialects)(main)
184184
base_stim_prog = load_reference_program("nested_for_loop.stim")
185185

186186
assert codegen(main) == base_stim_prog.rstrip()
187+
188+
189+
def test_nested_list():
190+
191+
# NOTE: While SquinToStim now has the ability to handle
192+
# the nested list outside of the kernel in this test,
193+
# in general it will be necessary to explicitly
194+
# annotate it as an IList so type inference can work
195+
# properly. Otherwise its global, mutable nature means
196+
# we cannot assume a static type.
197+
198+
pairs = [[0, 1], [2, 3]]
199+
200+
@squin.kernel
201+
def main():
202+
q = qubit.new(10)
203+
h = squin.op.h()
204+
for i in range(2):
205+
squin.qubit.apply(h, q[pairs[i][0]])
206+
207+
SquinToStimPass(main.dialects)(main)
208+
209+
base_stim_prog = load_reference_program("nested_list.stim")
210+
211+
assert codegen(main) == base_stim_prog.rstrip()

0 commit comments

Comments
 (0)