Skip to content

Commit 5425cfd

Browse files
authored
Fix qubit.apply lowering with multiple qubit indexing operations (#381)
When more than one qubit indexing operation is present, the lowering for `squin.qubit.apply` falls through which causes issues in the squin to stim infrastructure. To be precise, a statement like: ```python qubit.apply(op.some_gate, [list_of_qubits[i]]) ``` Is fine but @liupengy19 helped identify that something like: ```python qubit.apply(op.some_gate, [q[i], q[i+1]) ``` causes problems. With help from @david-pl and @kaihsin I was able to figure out that some of the existing lowering logic just needed to be made a bit more flexible (:
1 parent 862197a commit 5425cfd

File tree

5 files changed

+84
-45
lines changed

5 files changed

+84
-45
lines changed

src/bloqade/squin/rewrite/desugar.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
5454
qubits = node.qubits
5555

5656
if len(qubits) > 1 and all(q.type.is_subseteq(QubitType) for q in qubits):
57-
(qubits_ilist_stmt := ilist.New(qubits)).insert_before(node)
57+
(qubits_ilist_stmt := ilist.New(qubits)).insert_before(
58+
node
59+
) # qubits is just a tuple of SSAValues
5860
qubits_ilist = qubits_ilist_stmt.result
5961

6062
elif len(qubits) == 1 and qubits[0].type.is_subseteq(QubitType):
@@ -76,34 +78,44 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
7678
return RewriteResult()
7779

7880
is_ilist = isinstance(qbit_stmt := qubits[0].stmt, ilist.New)
81+
7982
if is_ilist:
80-
if len(qbit_stmt.values) != 1:
81-
return RewriteResult()
8283

83-
if not isinstance(
84-
qbit_getindex_result := qbit_stmt.values[0], ir.ResultValue
84+
if not all(
85+
isinstance(qbit_getindex_result, ir.ResultValue)
86+
for qbit_getindex_result in qbit_stmt.values
8587
):
8688
return RewriteResult()
8789

88-
qbit_getindex = qbit_getindex_result.stmt
90+
# Get the parent statement that the qubit came from
91+
# (should be a GetItem instance, see logic below)
92+
qbit_getindices = [
93+
qbit_getindex_result.stmt
94+
for qbit_getindex_result in qbit_stmt.values
95+
]
8996
else:
90-
qbit_getindex = qubits[0].stmt
97+
qbit_getindices = [qubit.stmt for qubit in qubits]
9198

92-
if not isinstance(qbit_getindex, py.indexing.GetItem):
99+
if any(
100+
not isinstance(qbit_getindex, py.indexing.GetItem)
101+
for qbit_getindex in qbit_getindices
102+
):
93103
return RewriteResult()
94104

95-
if not qbit_getindex.obj.type.is_subseteq(
96-
ilist.IListType[QubitType, types.Any]
105+
# The GetItem should have been applied on something that returns an IList of Qubits
106+
if any(
107+
not qbit_getindex.obj.type.is_subseteq(
108+
ilist.IListType[QubitType, types.Any]
109+
)
110+
for qbit_getindex in qbit_getindices
97111
):
98112
return RewriteResult()
99113

100114
if is_ilist:
101-
values = qbit_stmt.values
115+
qubits_ilist = qbit_stmt.result
102116
else:
103-
values = [qubits[0]]
104-
105-
(qubits_ilist_stmt := ilist.New(values=values)).insert_before(node)
106-
qubits_ilist = qubits_ilist_stmt.result
117+
(qubits_ilist_stmt := ilist.New(values=[qubits[0]])).insert_before(node)
118+
qubits_ilist = qubits_ilist_stmt.result
107119
else:
108120
return RewriteResult()
109121

test/squin/test_sugar.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def main():
4646
h = squin.op.h()
4747
x = squin.op.x()
4848

49-
# test applying to lest with getindex
49+
# test applying to list with getindex
5050
squin.qubit.apply(x, [q[0]])
5151

5252
# test apply with ast.Name
@@ -89,3 +89,18 @@ def main():
8989

9090
assert math.isclose(abs(ket[0]) ** 2, 1, abs_tol=1e-7)
9191
assert ket[1] == ket[2] == ket[3] == 0
92+
93+
94+
def test_apply_in_for_loop_index_multiple_index():
95+
96+
@squin.kernel
97+
def main():
98+
q = squin.qubit.new(3)
99+
squin.qubit.apply(squin.op.h(), q[0])
100+
cx = squin.op.cx()
101+
for i in range(2):
102+
squin.qubit.apply(cx, [q[i], q[i + 1]])
103+
104+
sim = StackMemorySimulator(min_qubits=3)
105+
ket = sim.state_vector(main)
106+
assert math.isclose(abs(ket[0]) ** 2, 0.5, abs_tol=1e-5)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
H 0
3+
CX 0 1
4+
CX 1 2

test/stim/passes/test_squin_noise_to_stim.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from kirin import ir
44

55
from bloqade.squin import op, noise, qubit, kernel
6+
from bloqade.stim.passes import SquinToStimPass
67

7-
from .test_squin_qubit_to_stim import codegen as _codegen, run_address_and_stim_passes
8+
from .test_squin_qubit_to_stim import codegen as _codegen
89

910

1011
def codegen(mt: ir.Method) -> str:
@@ -30,7 +31,7 @@ def test():
3031
qubit.apply(channel, q[0])
3132
return
3233

33-
run_address_and_stim_passes(test)
34+
SquinToStimPass(test.dialects)(test)
3435
expected_stim_program = load_reference_program("apply_pauli_channel_1.stim")
3536
assert codegen(test) == expected_stim_program
3637

@@ -44,7 +45,7 @@ def test():
4445
qubit.broadcast(channel, q)
4546
return
4647

47-
run_address_and_stim_passes(test)
48+
SquinToStimPass(test.dialects)(test)
4849
expected_stim_program = load_reference_program("broadcast_pauli_channel_1.stim")
4950
assert codegen(test) == expected_stim_program
5051

@@ -58,7 +59,7 @@ def test():
5859
qubit.broadcast(channel, q)
5960
return
6061

61-
run_address_and_stim_passes(test)
62+
SquinToStimPass(test.dialects)(test)
6263
expected_stim_program = load_reference_program(
6364
"broadcast_pauli_channel_1_many_qubits.stim"
6465
)
@@ -76,7 +77,7 @@ def test():
7677
qubit.broadcast(channel, q)
7778
return
7879

79-
run_address_and_stim_passes(test)
80+
SquinToStimPass(test.dialects)(test)
8081
expected_stim_program = load_reference_program(
8182
"broadcast_pauli_channel_1_reuse.stim"
8283
)
@@ -110,7 +111,7 @@ def test():
110111
qubit.broadcast(channel, q)
111112
return
112113

113-
run_address_and_stim_passes(test)
114+
SquinToStimPass(test.dialects)(test)
114115
expected_stim_program = load_reference_program("broadcast_pauli_channel_2.stim")
115116
assert codegen(test) == expected_stim_program
116117

@@ -143,7 +144,7 @@ def test():
143144
qubit.broadcast(channel, [q[2], q[3]])
144145
return
145146

146-
run_address_and_stim_passes(test)
147+
SquinToStimPass(test.dialects)(test)
147148
expected_stim_program = load_reference_program(
148149
"broadcast_pauli_channel_2_reuse_on_4_qubits.stim"
149150
)
@@ -159,7 +160,7 @@ def test():
159160
qubit.broadcast(channel, q)
160161
return
161162

162-
run_address_and_stim_passes(test)
163+
SquinToStimPass(test.dialects)(test)
163164
expected_stim_program = load_reference_program("broadcast_depolarize2.stim")
164165
assert codegen(test) == expected_stim_program
165166

@@ -173,7 +174,7 @@ def test():
173174
qubit.apply(channel, q[0])
174175
return
175176

176-
run_address_and_stim_passes(test)
177+
SquinToStimPass(test.dialects)(test)
177178
expected_stim_program = load_reference_program("apply_depolarize1.stim")
178179
assert codegen(test) == expected_stim_program
179180

@@ -187,7 +188,7 @@ def test():
187188
qubit.broadcast(channel, q)
188189
return
189190

190-
run_address_and_stim_passes(test)
191+
SquinToStimPass(test.dialects)(test)
191192
expected_stim_program = load_reference_program("broadcast_depolarize1.stim")
192193
assert codegen(test) == expected_stim_program
193194

@@ -202,7 +203,7 @@ def test():
202203
qubit.broadcast(channel, q)
203204
return
204205

205-
run_address_and_stim_passes(test)
206+
SquinToStimPass(test.dialects)(test)
206207
expected_stim_program = load_reference_program(
207208
"broadcast_iid_bit_flip_channel.stim"
208209
)
@@ -219,7 +220,7 @@ def test():
219220
qubit.broadcast(channel, q)
220221
return
221222

222-
run_address_and_stim_passes(test)
223+
SquinToStimPass(test.dialects)(test)
223224
expected_stim_program = load_reference_program(
224225
"broadcast_iid_phase_flip_channel.stim"
225226
)
@@ -236,6 +237,6 @@ def test():
236237
qubit.broadcast(channel, q)
237238
return
238239

239-
run_address_and_stim_passes(test)
240+
SquinToStimPass(test.dialects)(test)
240241
expected_stim_program = load_reference_program("broadcast_iid_y_flip_channel.stim")
241242
assert codegen(test) == expected_stim_program

test/stim/passes/test_squin_qubit_to_stim.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,12 @@
22
import math
33

44
from kirin import ir
5-
from kirin.rewrite import Walk
65
from kirin.dialects import py
76

87
from bloqade import squin
98
from bloqade.squin import op, noise, qubit, kernel
109
from bloqade.stim.emit import EmitStimMain
1110
from bloqade.stim.passes import SquinToStimPass
12-
from bloqade.squin.rewrite import WrapAddressAnalysis
13-
from bloqade.analysis.address import AddressAnalysis
1411

1512

1613
# Taken gratuitously from Kai's unit test
@@ -38,12 +35,6 @@ def load_reference_program(filename):
3835
return f.read()
3936

4037

41-
def run_address_and_stim_passes(test: ir.Method):
42-
addr_frame, _ = AddressAnalysis(test.dialects).run_analysis(test)
43-
Walk(WrapAddressAnalysis(address_analysis=addr_frame.entries)).rewrite(test.code)
44-
SquinToStimPass(test.dialects)(test)
45-
46-
4738
def test_qubit():
4839
@kernel
4940
def test():
@@ -57,7 +48,7 @@ def test():
5748
squin.qubit.measure(ql)
5849
return
5950

60-
run_address_and_stim_passes(test)
51+
SquinToStimPass(test.dialects)(test)
6152
base_stim_prog = load_reference_program("qubit.stim")
6253

6354
assert codegen(test) == base_stim_prog.rstrip()
@@ -74,7 +65,7 @@ def test():
7465
squin.qubit.measure(q[0])
7566
return
7667

77-
run_address_and_stim_passes(test)
68+
SquinToStimPass(test.dialects)(test)
7869
base_stim_prog = load_reference_program("qubit_reset.stim")
7970

8071
assert codegen(test) == base_stim_prog.rstrip()
@@ -91,7 +82,7 @@ def test():
9182
squin.qubit.measure(ql)
9283
return
9384

94-
run_address_and_stim_passes(test)
85+
SquinToStimPass(test.dialects)(test)
9586
base_stim_prog = load_reference_program("qubit_broadcast.stim")
9687

9788
assert codegen(test) == base_stim_prog.rstrip()
@@ -111,7 +102,7 @@ def test():
111102
squin.qubit.measure(ql)
112103
return
113104

114-
run_address_and_stim_passes(test)
105+
SquinToStimPass(test.dialects)(test)
115106
base_stim_prog = load_reference_program("qubit_loss.stim")
116107

117108
assert codegen(test) == base_stim_prog.rstrip()
@@ -129,7 +120,7 @@ def test():
129120
squin.qubit.measure(q)
130121
return
131122

132-
run_address_and_stim_passes(test)
123+
SquinToStimPass(test.dialects)(test)
133124

134125
base_stim_prog = load_reference_program("u3_to_clifford.stim")
135126

@@ -144,7 +135,7 @@ def test():
144135
qubit.broadcast(op.sqrt_x(), q)
145136
return
146137

147-
run_address_and_stim_passes(test)
138+
SquinToStimPass(test.dialects)(test)
148139

149140
assert codegen(test).strip() == "SQRT_X 0"
150141

@@ -157,6 +148,22 @@ def test():
157148
qubit.broadcast(op.sqrt_y(), q)
158149
return
159150

160-
run_address_and_stim_passes(test)
151+
SquinToStimPass(test.dialects)(test)
161152

162153
assert codegen(test).strip() == "SQRT_Y 0"
154+
155+
156+
def test_for_loop_rewrite():
157+
158+
@squin.kernel
159+
def main():
160+
q = squin.qubit.new(3)
161+
squin.qubit.apply(squin.op.h(), q[0])
162+
cx = squin.op.cx()
163+
for i in range(2):
164+
squin.qubit.apply(cx, [q[i], q[i + 1]])
165+
166+
SquinToStimPass(main.dialects)(main)
167+
base_stim_prog = load_reference_program("for_loop.stim")
168+
169+
assert codegen(main) == base_stim_prog.rstrip()

0 commit comments

Comments
 (0)