Skip to content

Commit 9f7b388

Browse files
ehua7365johnzl-777
andauthored
Support Squin Depolarize to Stim DEPOLARIZE1 (#367)
Support rewriting squin depolarizing noise to stim depolarizing noise for single qubits. --------- Co-authored-by: John Long <[email protected]>
1 parent bf3d3e9 commit 9f7b388

File tree

7 files changed

+75
-18
lines changed

7 files changed

+75
-18
lines changed

_typos.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ Braket = "Braket"
1313
mch = "mch"
1414
IY = "IY"
1515
ket = "ket"
16+
typ = "typ"

src/bloqade/stim/passes/squin_to_stim.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
DeadCodeElimination,
99
CommonSubexpressionElimination,
1010
)
11-
from kirin.analysis import const
1211
from kirin.ir.method import Method
1312
from kirin.passes.abc import Pass
1413
from kirin.rewrite.abc import RewriteResult
@@ -33,18 +32,11 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
3332
# propagate constants
3433
rewrite_result = fold_pass(mt)
3534

36-
cp_frame, _ = const.Propagate(dialects=mt.dialects).run_analysis(mt)
37-
cp_results = cp_frame.entries
38-
3935
# Assume that address analysis and
4036
# wrapping has been done before this pass!
4137

4238
# Rewrite the noise statements first.
43-
rewrite_result = (
44-
Walk(SquinNoiseToStim(cp_results=cp_results))
45-
.rewrite(mt.code)
46-
.join(rewrite_result)
47-
)
39+
rewrite_result = Walk(SquinNoiseToStim()).rewrite(mt.code).join(rewrite_result)
4840

4941
# Wrap Rewrite + SquinToStim can happen w/ standard walk
5042

src/bloqade/stim/rewrite/squin_noise.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
from typing import Dict, Tuple
1+
from typing import Tuple
22
from dataclasses import dataclass
33

44
from kirin.ir import SSAValue, Statement
5-
from kirin.analysis import const
6-
from kirin.dialects import py
5+
from kirin.dialects import py, ilist
76
from kirin.rewrite.abc import RewriteRule, RewriteResult
87

98
from bloqade.squin import op, wire, noise as squin_noise, qubit
109
from bloqade.stim.dialects import noise as stim_noise
1110
from bloqade.stim.rewrite.util import (
11+
get_const_value,
1212
create_wire_passthrough,
1313
insert_qubit_idx_after_apply,
1414
)
@@ -17,8 +17,6 @@
1717
@dataclass
1818
class SquinNoiseToStim(RewriteRule):
1919

20-
cp_results: Dict[SSAValue, const.Result]
21-
2220
def rewrite_Statement(self, node: Statement) -> RewriteResult:
2321
match node:
2422
case qubit.Apply() | qubit.Broadcast():
@@ -67,7 +65,7 @@ def rewrite_PauliError(
6765
assert isinstance(squin_channel, squin_noise.stmts.PauliError)
6866
basis = squin_channel.basis.owner
6967
assert isinstance(basis, op.stmts.PauliOp)
70-
p = self.cp_results.get(squin_channel.p).data
68+
p = get_const_value(float, squin_channel.p)
7169

7270
p_stmt = py.Constant(p)
7371
p_stmt.insert_before(stmt)
@@ -90,7 +88,7 @@ def rewrite_SingleQubitPauliChannel(
9088
squin_channel = stmt.operator.owner
9189
assert isinstance(squin_channel, squin_noise.stmts.SingleQubitPauliChannel)
9290

93-
params = self.cp_results.get(squin_channel.params).data
91+
params = get_const_value(ilist.IList, squin_channel.params)
9492
new_stmts = [
9593
p_x := py.Constant(params[0]),
9694
p_y := py.Constant(params[1]),
@@ -117,7 +115,7 @@ def rewrite_TwoQubitPauliChannel(
117115
squin_channel = stmt.operator.owner
118116
assert isinstance(squin_channel, squin_noise.stmts.TwoQubitPauliChannel)
119117

120-
params = self.cp_results.get(squin_channel.params).data
118+
params = get_const_value(ilist.IList, squin_channel.params)
121119
param_stmts = [py.Constant(p) for p in params]
122120
for param_stmt in param_stmts:
123121
param_stmt.insert_before(stmt)
@@ -141,3 +139,20 @@ def rewrite_TwoQubitPauliChannel(
141139
pzz=param_stmts[14].result,
142140
)
143141
return stim_stmt
142+
143+
def rewrite_Depolarize(
144+
self,
145+
stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
146+
qubit_idx_ssas: Tuple[SSAValue],
147+
) -> Statement:
148+
"""Rewrite squin.noise.Depolarize to stim.Depolarize1."""
149+
150+
squin_channel = stmt.operator.owner
151+
assert isinstance(squin_channel, squin_noise.stmts.Depolarize)
152+
153+
p = get_const_value(float, squin_channel.p)
154+
p_stmt = py.Constant(p)
155+
p_stmt.insert_before(stmt)
156+
157+
stim_stmt = stim_noise.Depolarize1(targets=qubit_idx_ssas, p=p_stmt.result)
158+
return stim_stmt

src/bloqade/stim/rewrite/util.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from kirin import ir
1+
from typing import TypeVar
2+
3+
from kirin import ir, interp
4+
from kirin.analysis import const
25
from kirin.dialects import py
36
from kirin.rewrite.abc import RewriteResult
47

@@ -205,3 +208,19 @@ def is_measure_result_used(
205208
Check if the result of a measure statement is used in the program.
206209
"""
207210
return bool(stmt.result.uses)
211+
212+
213+
T = TypeVar("T")
214+
215+
216+
def get_const_value(typ: type[T], value: ir.SSAValue) -> T:
217+
if isinstance(hint := value.hints.get("const"), const.Value):
218+
data = hint.data
219+
if isinstance(data, typ):
220+
return hint.data
221+
raise interp.InterpreterError(
222+
f"Expected constant value <type = {typ}>, got {data}"
223+
)
224+
raise interp.InterpreterError(
225+
f"Expected constant value <type = {typ}>, got {value}"
226+
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
DEPOLARIZE1(0.01000000) 0
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
DEPOLARIZE1(0.01000000) 0 1 2 3

test/stim/passes/test_squin_noise_to_stim.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,34 @@ def test():
150150
assert codegen(test) == expected_stim_program
151151

152152

153+
def test_apply_depolarize1():
154+
155+
@kernel
156+
def test():
157+
q = qubit.new(1)
158+
channel = noise.depolarize(p=0.01)
159+
qubit.apply(channel, q[0])
160+
return
161+
162+
run_address_and_stim_passes(test)
163+
expected_stim_program = load_reference_program("apply_depolarize1.stim")
164+
assert codegen(test) == expected_stim_program
165+
166+
167+
def test_broadcast_depolarize1():
168+
169+
@kernel
170+
def test():
171+
q = qubit.new(4)
172+
channel = noise.depolarize(p=0.01)
173+
qubit.broadcast(channel, q)
174+
return
175+
176+
run_address_and_stim_passes(test)
177+
expected_stim_program = load_reference_program("broadcast_depolarize1.stim")
178+
assert codegen(test) == expected_stim_program
179+
180+
153181
def test_broadcast_iid_bit_flip_channel():
154182

155183
@kernel

0 commit comments

Comments
 (0)