Skip to content

Commit 680a4f9

Browse files
johnzl-777david-pl
andcommitted
Fix nested for-loop handling in squin -> stim pass (#394)
Previously, the following would cause the squin to stim pass to fail: ```python for i in range(...): for j in range(...): qubit.apply(op.cx(), [q[i], q[j]) ``` This was because the unroll rule was not successful in fully unrolling the loops, getting (at most) one unroll before silently giving up. This behavior was first identified by @liupengy19 and encountered more recently by @ehua7365 as well. @kaihsin pointed out to me that the problem might be because of a lack of constprop occuring between each unroll. This turns out to be the case and I've implemented a mini-pass that will do the constprop after each unroll, with the pass itself applied via fixpoint instead of the more traditional unsafe_run, single application style. I've also taken the liberty to rename some of the tests to be a bit more accurate and removed some outdated comments. --------- Co-authored-by: David Plankensteiner <[email protected]>
1 parent c0a19d8 commit 680a4f9

File tree

4 files changed

+59
-15
lines changed

4 files changed

+59
-15
lines changed

src/bloqade/stim/passes/squin_to_stim.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass
22

3-
from kirin.passes import Fold
3+
from kirin.passes import Fold, HintConst
44
from kirin.rewrite import (
55
Walk,
66
Chain,
@@ -38,6 +38,37 @@
3838
from ..rewrite.ifs_to_stim import IfToStim
3939

4040

41+
@dataclass
42+
class AggressiveForLoopUnroll(Pass):
43+
"""
44+
Aggressive unrolling of for loops, addresses cases where unroll
45+
does not successfully handle nested loops because of a lack of constprop.
46+
47+
This should be invoked via fixpoint to let this be repeatedly applied until
48+
no further rewrites are possible.
49+
"""
50+
51+
def unsafe_run(self, mt: Method) -> RewriteResult:
52+
rule = Chain(
53+
InlineGetField(),
54+
InlineGetItem(),
55+
scf.unroll.ForLoop(),
56+
scf.trim.UnusedYield(),
57+
)
58+
59+
# Intentionally only walk ONCE, let fixpoint happen with the WHOLE pass
60+
# so that HintConst gets run right after, allowing subsequent unrolls to happen
61+
rewrite_result = Walk(rule).rewrite(mt.code)
62+
63+
rewrite_result = (
64+
HintConst(dialects=mt.dialects, no_raise=self.no_raise)
65+
.unsafe_run(mt)
66+
.join(rewrite_result)
67+
)
68+
69+
return rewrite_result
70+
71+
4172
@dataclass
4273
class SquinToStimPass(Pass):
4374

@@ -48,15 +79,10 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
4879
dialects=mt.dialects, no_raise=self.no_raise
4980
).unsafe_run(mt)
5081

51-
rule = Chain(
52-
InlineGetField(),
53-
InlineGetItem(),
54-
scf.unroll.ForLoop(),
55-
scf.trim.UnusedYield(),
56-
)
57-
rewrite_result = Fixpoint(Walk(rule)).rewrite(mt.code).join(rewrite_result)
58-
# fold_pass = Fold(mt.dialects, no_raise=self.no_raise)
59-
# rewrite_result = fold_pass(mt)
82+
rewrite_result = AggressiveForLoopUnroll(
83+
dialects=mt.dialects, no_raise=self.no_raise
84+
).fixpoint(mt)
85+
6086
rewrite_result = (
6187
Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(rewrite_result)
6288
)
@@ -66,9 +92,6 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
6692
.join(rewrite_result)
6793
)
6894

69-
# run typeinfer again after unroll etc. because we now insert
70-
# a lot of new nodes, which might have more precise types
71-
# self.typeinfer.unsafe_run(mt)
7295
rewrite_result = (
7396
Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
7497
.rewrite(mt.code)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
H 0
3+
CX 0 2
4+
CX 1 2

test/stim/passes/test_squin_qubit_to_stim.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def test():
153153
assert codegen(test).strip() == "SQRT_Y 0"
154154

155155

156-
def test_for_loop_rewrite():
156+
def test_for_loop_nontrivial_index_rewrite():
157157

158158
@squin.kernel
159159
def main():
@@ -164,6 +164,23 @@ def main():
164164
squin.qubit.apply(cx, [q[i], q[i + 1]])
165165

166166
SquinToStimPass(main.dialects)(main)
167-
base_stim_prog = load_reference_program("for_loop.stim")
167+
base_stim_prog = load_reference_program("for_loop_nontrivial_index.stim")
168+
169+
assert codegen(main) == base_stim_prog.rstrip()
170+
171+
172+
def test_nested_for_loop_rewrite():
173+
174+
@squin.kernel
175+
def main():
176+
q = squin.qubit.new(5)
177+
squin.qubit.apply(squin.op.h(), q[0])
178+
cx = squin.op.cx()
179+
for i in range(2):
180+
for j in range(2, 3):
181+
squin.qubit.apply(cx, [q[i], q[j]])
182+
183+
SquinToStimPass(main.dialects)(main)
184+
base_stim_prog = load_reference_program("nested_for_loop.stim")
168185

169186
assert codegen(main) == base_stim_prog.rstrip()

0 commit comments

Comments
 (0)