Skip to content

Commit a28535d

Browse files
david-plweinbe58
andcommitted
Lowering of squin noise to cirq (#343)
Closes #317 --------- Co-authored-by: Phillip Weinberg <[email protected]>
1 parent 8b7c6cd commit a28535d

File tree

6 files changed

+166
-8
lines changed

6 files changed

+166
-8
lines changed

src/bloqade/squin/cirq/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
from .. import kernel
1010

1111
# NOTE: just to register methods
12-
from .emit import op as op, qubit as qubit
12+
from .emit import op as op, noise as noise, qubit as qubit
1313
from .lowering import Squin
14+
from ..noise.rewrite import RewriteNoiseStmts
1415
from .emit.emit_circuit import EmitCirq
1516

1617

@@ -235,7 +236,12 @@ def main():
235236
)
236237

237238
emitter = EmitCirq(qubits=qubits)
238-
return emitter.run(mt, args=())
239+
240+
# Rewrite noise statements
241+
mt_ = mt.similar(mt.dialects)
242+
RewriteNoiseStmts(mt_.dialects)(mt_)
243+
244+
return emitter.run(mt_, args=())
239245

240246

241247
def dump_circuit(mt: ir.Method, qubits: Sequence[cirq.Qid] | None = None, **kwargs):
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import cirq
2+
from kirin.emit import EmitError
3+
from kirin.interp import MethodTable, impl
4+
5+
from ... import noise
6+
from .runtime import (
7+
KronRuntime,
8+
BasicOpRuntime,
9+
OperatorRuntimeABC,
10+
PauliStringRuntime,
11+
)
12+
from .emit_circuit import EmitCirq, EmitCirqFrame
13+
14+
15+
@noise.dialect.register(key="emit.cirq")
16+
class EmitCirqNoiseMethods(MethodTable):
17+
18+
@impl(noise.stmts.StochasticUnitaryChannel)
19+
def stochastic_unitary_channel(
20+
self,
21+
emit: EmitCirq,
22+
frame: EmitCirqFrame,
23+
stmt: noise.stmts.StochasticUnitaryChannel,
24+
):
25+
ops = frame.get(stmt.operators)
26+
ps = frame.get(stmt.probabilities)
27+
28+
error_probabilities = {self._op_to_key(op_): p for op_, p in zip(ops, ps)}
29+
cirq_op = cirq.asymmetric_depolarize(error_probabilities=error_probabilities)
30+
return (BasicOpRuntime(cirq_op),)
31+
32+
@staticmethod
33+
def _op_to_key(operator: OperatorRuntimeABC) -> str:
34+
match operator:
35+
case KronRuntime():
36+
key_lhs = EmitCirqNoiseMethods._op_to_key(operator.lhs)
37+
key_rhs = EmitCirqNoiseMethods._op_to_key(operator.rhs)
38+
return key_lhs + key_rhs
39+
40+
case BasicOpRuntime():
41+
return str(operator.gate)
42+
43+
case PauliStringRuntime():
44+
return operator.string
45+
46+
case _:
47+
raise EmitError(
48+
f"Unexpected operator runtime in StochasticUnitaryChannel of type {type(operator).__name__} encountered!"
49+
)

src/bloqade/squin/cirq/emit/runtime.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ def apply(
2121

2222
def unsafe_apply(
2323
self, qubits: Sequence[cirq.Qid], adjoint: bool = False
24-
) -> list[cirq.Operation]: ...
24+
) -> list[cirq.Operation]:
25+
raise NotImplementedError(
26+
f"Apply method needs to be implemented in {self.__class__.__name__}"
27+
)
2528

2629

2730
@dataclass
@@ -38,6 +41,11 @@ class BasicOpRuntime(UnsafeOperatorRuntimeABC):
3841
def num_qubits(self) -> int:
3942
return self.gate.num_qubits()
4043

44+
def unsafe_apply(
45+
self, qubits: Sequence[cirq.Qid], adjoint: bool = False
46+
) -> list[cirq.Operation]:
47+
return [self.gate(*qubits)]
48+
4149

4250
@dataclass
4351
class UnitaryRuntime(BasicOpRuntime):

src/bloqade/squin/noise/_wrapper.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from typing import Literal
2+
3+
from kirin.dialects import ilist
14
from kirin.lowering import wraps
25

36
from bloqade.squin.op.types import Op
@@ -18,11 +21,15 @@ def depolarize(p: float) -> Op: ...
1821

1922

2023
@wraps(stmts.SingleQubitPauliChannel)
21-
def single_qubit_pauli_channel(params: tuple[float, float, float]) -> Op: ...
24+
def single_qubit_pauli_channel(
25+
params: ilist.IList[float, Literal[3]] | list[float] | tuple[float, float, float],
26+
) -> Op: ...
2227

2328

2429
@wraps(stmts.TwoQubitPauliChannel)
25-
def two_qubit_pauli_channel(params: tuple[float, ...]) -> Op: ...
30+
def two_qubit_pauli_channel(
31+
params: ilist.IList[float, Literal[15]] | list[float] | tuple[float, ...],
32+
) -> Op: ...
2633

2734

2835
@wraps(stmts.QubitLoss)

src/bloqade/squin/noise/rewrite.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ def rewrite_single_qubit_pauli_channel(
5858
def rewrite_two_qubit_pauli_channel(
5959
self, node: TwoQubitPauliChannel
6060
) -> RewriteResult:
61-
paulis = (X(), Y(), Z(), Identity(sites=1))
61+
paulis = (Identity(sites=1), X(), Y(), Z())
6262
for op in paulis:
6363
op.insert_before(node)
6464

65-
# NOTE: collect list so we can skip the last entry, which will be two identities
66-
combinations = list(itertools.product(paulis, repeat=2))[:-1]
65+
# NOTE: collect list so we can skip the first entry, which will be two identities
66+
combinations = list(itertools.product(paulis, repeat=2))[1:]
6767
operators: list[ir.SSAValue] = []
6868
for pauli_1, pauli_2 in combinations:
6969
op = Kron(pauli_1.result, pauli_2.result)
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import cirq
2+
3+
from bloqade import squin
4+
5+
6+
def test_pauli_channel(run_sim: bool = False):
7+
@squin.kernel
8+
def main():
9+
q = squin.qubit.new(2)
10+
h = squin.op.h()
11+
cx = squin.op.cx()
12+
squin.qubit.apply(h, q[0])
13+
dpl = squin.noise.depolarize(0.1)
14+
squin.qubit.apply(dpl, q[0])
15+
squin.qubit.apply(cx, q)
16+
single_qubit_noise = squin.noise.single_qubit_pauli_channel([0.1, 0.12, 0.13])
17+
squin.qubit.apply(single_qubit_noise, q[1])
18+
two_qubit_noise = squin.noise.two_qubit_pauli_channel(
19+
[
20+
0.036,
21+
0.007,
22+
0.035,
23+
0.022,
24+
0.063,
25+
0.024,
26+
0.006,
27+
0.033,
28+
0.014,
29+
0.019,
30+
0.023,
31+
0.058,
32+
0.0,
33+
0.0,
34+
0.064,
35+
]
36+
)
37+
squin.qubit.apply(two_qubit_noise, q)
38+
squin.qubit.measure(q)
39+
40+
main.print()
41+
42+
circuit = squin.cirq.emit_circuit(main)
43+
44+
print(circuit)
45+
46+
if run_sim:
47+
sim = cirq.Simulator()
48+
sim.run(circuit)
49+
50+
51+
def test_pauli_error(run_sim: bool = False):
52+
@squin.kernel
53+
def main():
54+
q = squin.qubit.new(2)
55+
x = squin.op.x()
56+
n = squin.noise.pauli_error(x, 0.1)
57+
squin.qubit.apply(n, q[0])
58+
squin.qubit.measure(q)
59+
60+
main.print()
61+
62+
circuit = squin.cirq.emit_circuit(main)
63+
64+
print(circuit)
65+
66+
if run_sim:
67+
sim = cirq.Simulator()
68+
sim.run(circuit)
69+
70+
71+
def test_pperror(run_sim: bool = False):
72+
@squin.kernel
73+
def main():
74+
q = squin.qubit.new(3)
75+
ps = squin.op.pauli_string(string="XYZ")
76+
n = squin.noise.pp_error(ps, 0.1)
77+
squin.qubit.apply(n, q)
78+
squin.qubit.measure(q)
79+
80+
main.print()
81+
82+
circuit = squin.cirq.emit_circuit(main)
83+
84+
print(circuit)
85+
86+
if run_sim:
87+
sim = cirq.Simulator()
88+
sim.run(circuit)

0 commit comments

Comments
 (0)