Skip to content

Commit 7ae0f1f

Browse files
committed
Implement runtime for squin noise statements (#273)
Closes #227
1 parent d31a457 commit 7ae0f1f

File tree

11 files changed

+380
-24
lines changed

11 files changed

+380
-24
lines changed

src/bloqade/pyqrack/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# NOTE: The following import is for registering the method tables
1717
from .noise import native as native
1818
from .qasm2 import uop as uop, core as core, glob as glob, parallel as parallel
19-
from .squin import op as op, qubit as qubit
19+
from .squin import op as op, noise as noise, qubit as qubit
2020
from .device import (
2121
StackMemorySimulator as StackMemorySimulator,
2222
DynamicMemorySimulator as DynamicMemorySimulator,
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import native as native
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import random
2+
import typing
3+
from dataclasses import dataclass
4+
5+
from kirin import interp
6+
from kirin.dialects import ilist
7+
8+
from bloqade.pyqrack import QubitState, PyQrackQubit, PyQrackInterpreter
9+
from bloqade.squin.noise.stmts import QubitLoss, StochasticUnitaryChannel
10+
from bloqade.squin.noise._dialect import dialect as squin_noise_dialect
11+
12+
from ..runtime import OperatorRuntimeABC
13+
14+
15+
@dataclass(frozen=True)
16+
class StochasticUnitaryChannelRuntime(OperatorRuntimeABC):
17+
operators: ilist.IList[OperatorRuntimeABC, typing.Any]
18+
probabilities: ilist.IList[float, typing.Any]
19+
20+
@property
21+
def n_sites(self) -> int:
22+
n = self.operators[0].n_sites
23+
for op in self.operators[1:]:
24+
assert (
25+
op.n_sites == n
26+
), "Encountered a stochastic unitary channel with operators of different size!"
27+
return n
28+
29+
def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
30+
# NOTE: probabilities don't necessarily sum to 1; could be no noise event should occur
31+
p_no_op = 1 - sum(self.probabilities)
32+
if random.uniform(0.0, 1.0) < p_no_op:
33+
return
34+
35+
selected_ops = random.choices(self.operators, weights=self.probabilities)
36+
for op in selected_ops:
37+
op.apply(*qubits, adjoint=adjoint)
38+
39+
40+
@dataclass(frozen=True)
41+
class QubitLossRuntime(OperatorRuntimeABC):
42+
p: float
43+
44+
@property
45+
def n_sites(self) -> int:
46+
return 1
47+
48+
def apply(self, qubit: PyQrackQubit, adjoint: bool = False) -> None:
49+
if random.uniform(0.0, 1.0) < self.p:
50+
qubit.state = QubitState.Lost
51+
52+
53+
@squin_noise_dialect.register(key="pyqrack")
54+
class PyQrackMethods(interp.MethodTable):
55+
@interp.impl(StochasticUnitaryChannel)
56+
def stochastic_unitary_channel(
57+
self,
58+
interp: PyQrackInterpreter,
59+
frame: interp.Frame,
60+
stmt: StochasticUnitaryChannel,
61+
):
62+
operators = frame.get(stmt.operators)
63+
probabilities = frame.get(stmt.probabilities)
64+
65+
return (StochasticUnitaryChannelRuntime(operators, probabilities),)
66+
67+
@interp.impl(QubitLoss)
68+
def qubit_loss(
69+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: QubitLoss
70+
):
71+
p = frame.get(stmt.p)
72+
return (QubitLossRuntime(p),)

src/bloqade/squin/groups.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from kirin.rewrite import Walk, Chain
44
from kirin.dialects import ilist
55

6-
from . import op, wire, qubit
6+
from . import op, wire, noise, qubit
77
from .op.rewrite import PyMultToSquinMult
88
from .rewrite.desugar import ApplyDesugarRule, MeasureDesugarRule
99

1010

11-
@ir.dialect_group(structural_no_opt.union([op, qubit]))
11+
@ir.dialect_group(structural_no_opt.union([op, qubit, noise]))
1212
def kernel(self):
1313
fold_pass = passes.Fold(self)
1414
typeinfer_pass = passes.TypeInfer(self)
@@ -36,7 +36,7 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
3636
return run_pass
3737

3838

39-
@ir.dialect_group(structural_no_opt.union([op, wire]))
39+
@ir.dialect_group(structural_no_opt.union([op, wire, noise]))
4040
def wired(self):
4141
py_mult_to_mult_pass = PyMultToSquinMult(self)
4242

src/bloqade/squin/noise/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,7 @@
44
pp_error as pp_error,
55
depolarize as depolarize,
66
qubit_loss as qubit_loss,
7-
pauli_channel as pauli_channel,
7+
pauli_error as pauli_error,
8+
two_qubit_pauli_channel as two_qubit_pauli_channel,
9+
single_qubit_pauli_channel as single_qubit_pauli_channel,
810
)

src/bloqade/squin/noise/_wrapper.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@ def pp_error(op: Op, p: float) -> Op: ...
1414

1515

1616
@wraps(stmts.Depolarize)
17-
def depolarize(n_qubits: int, p: float) -> Op: ...
17+
def depolarize(p: float) -> Op: ...
1818

1919

20-
@wraps(stmts.PauliChannel)
21-
def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op: ...
20+
@wraps(stmts.SingleQubitPauliChannel)
21+
def single_qubit_pauli_channel(params: tuple[float, float, float]) -> Op: ...
22+
23+
24+
@wraps(stmts.TwoQubitPauliChannel)
25+
def two_qubit_pauli_channel(params: tuple[float, ...]) -> Op: ...
2226

2327

2428
@wraps(stmts.QubitLoss)

src/bloqade/squin/noise/rewrite.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import itertools
2+
3+
from kirin import ir
4+
from kirin.passes import Pass
5+
from kirin.rewrite import Walk
6+
from kirin.dialects import ilist
7+
from kirin.rewrite.abc import RewriteRule, RewriteResult
8+
9+
from .stmts import (
10+
PPError,
11+
QubitLoss,
12+
Depolarize,
13+
PauliError,
14+
NoiseChannel,
15+
TwoQubitPauliChannel,
16+
SingleQubitPauliChannel,
17+
StochasticUnitaryChannel,
18+
)
19+
from ..op.stmts import X, Y, Z, Kron, Identity
20+
21+
22+
class _RewriteNoiseStmts(RewriteRule):
23+
"""Rewrites squin noise statements to StochasticUnitaryChannel"""
24+
25+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
26+
if not isinstance(node, NoiseChannel) or isinstance(node, QubitLoss):
27+
return RewriteResult()
28+
29+
return getattr(self, "rewrite_" + node.name)(node)
30+
31+
def rewrite_pauli_error(self, node: PauliError) -> RewriteResult:
32+
(operators := ilist.New(values=(node.basis,))).insert_before(node)
33+
(ps := ilist.New(values=(node.p,))).insert_before(node)
34+
stochastic_channel = StochasticUnitaryChannel(
35+
operators=operators.result, probabilities=ps.result
36+
)
37+
38+
node.replace_by(stochastic_channel)
39+
return RewriteResult(has_done_something=True)
40+
41+
def rewrite_single_qubit_pauli_channel(
42+
self, node: SingleQubitPauliChannel
43+
) -> RewriteResult:
44+
paulis = (X(), Y(), Z())
45+
paulis_ssa: list[ir.SSAValue] = []
46+
for op in paulis:
47+
op.insert_before(node)
48+
paulis_ssa.append(op.result)
49+
50+
(pauli_ops := ilist.New(values=paulis_ssa)).insert_before(node)
51+
52+
stochastic_unitary = StochasticUnitaryChannel(
53+
operators=pauli_ops.result, probabilities=node.params
54+
)
55+
node.replace_by(stochastic_unitary)
56+
return RewriteResult(has_done_something=True)
57+
58+
def rewrite_two_qubit_pauli_channel(
59+
self, node: TwoQubitPauliChannel
60+
) -> RewriteResult:
61+
paulis = (X(), Y(), Z(), Identity(sites=1))
62+
for op in paulis:
63+
op.insert_before(node)
64+
65+
# NOTE: collect list so we can skip the last entry, which will be two identities
66+
combinations = list(itertools.product(paulis, repeat=2))[:-1]
67+
operators: list[ir.SSAValue] = []
68+
for pauli_1, pauli_2 in combinations:
69+
op = Kron(pauli_1.result, pauli_2.result)
70+
op.insert_before(node)
71+
operators.append(op.result)
72+
73+
(operator_list := ilist.New(values=operators)).insert_before(node)
74+
stochastic_unitary = StochasticUnitaryChannel(
75+
operators=operator_list.result, probabilities=node.params
76+
)
77+
78+
node.replace_by(stochastic_unitary)
79+
return RewriteResult(has_done_something=True)
80+
81+
def rewrite_p_p_error(self, node: PPError) -> RewriteResult:
82+
(operators := ilist.New(values=(node.op,))).insert_before(node)
83+
(ps := ilist.New(values=(node.p,))).insert_before(node)
84+
stochastic_channel = StochasticUnitaryChannel(
85+
operators=operators.result, probabilities=ps.result
86+
)
87+
88+
node.replace_by(stochastic_channel)
89+
return RewriteResult(has_done_something=True)
90+
91+
def rewrite_depolarize(self, node: Depolarize) -> RewriteResult:
92+
paulis = (X(), Y(), Z())
93+
operators: list[ir.SSAValue] = []
94+
for op in paulis:
95+
op.insert_before(node)
96+
operators.append(op.result)
97+
98+
(operator_list := ilist.New(values=operators)).insert_before(node)
99+
(ps := ilist.New(values=[node.p for _ in range(3)])).insert_before(node)
100+
101+
stochastic_unitary = StochasticUnitaryChannel(
102+
operators=operator_list.result, probabilities=ps.result
103+
)
104+
node.replace_by(stochastic_unitary)
105+
106+
return RewriteResult(has_done_something=True)
107+
108+
109+
class RewriteNoiseStmts(Pass):
110+
def unsafe_run(self, mt: ir.Method):
111+
return Walk(_RewriteNoiseStmts()).rewrite(mt.code)

src/bloqade/squin/noise/stmts.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
1-
from kirin import ir, types
1+
from kirin import ir, types, lowering
22
from kirin.decl import info, statement
3+
from kirin.dialects import ilist
34

45
from bloqade.squin.op.types import OpType
56

67
from ._dialect import dialect
8+
from ..op.types import NumOperators
79

810

911
@statement
1012
class NoiseChannel(ir.Statement):
11-
pass
13+
traits = frozenset({lowering.FromPythonCall()})
14+
result: ir.ResultValue = info.result(OpType)
1215

1316

1417
@statement(dialect=dialect)
1518
class PauliError(NoiseChannel):
1619
basis: ir.SSAValue = info.argument(OpType)
1720
p: ir.SSAValue = info.argument(types.Float)
18-
result: ir.ResultValue = info.result(OpType)
1921

2022

2123
@statement(dialect=dialect)
@@ -26,34 +28,37 @@ class PPError(NoiseChannel):
2628

2729
op: ir.SSAValue = info.argument(OpType)
2830
p: ir.SSAValue = info.argument(types.Float)
29-
result: ir.ResultValue = info.result(OpType)
3031

3132

3233
@statement(dialect=dialect)
3334
class Depolarize(NoiseChannel):
3435
"""
35-
Apply n-qubit depolaize error to qubits
36-
NOTE For Stim, this can only accept 1 or 2 qubits
36+
Apply depolarize error to qubit
3737
"""
3838

39-
n_qubits: int = info.attribute(types.Int)
4039
p: ir.SSAValue = info.argument(types.Float)
41-
result: ir.ResultValue = info.result(OpType)
4240

4341

4442
@statement(dialect=dialect)
45-
class PauliChannel(NoiseChannel):
46-
# NOTE:
47-
# 1-qubit 3 params px, py, pz
48-
# 2-qubit 15 params pix, piy, piz, pxi, pxx, pxy, pxz, pyi, pyx ..., pzz
49-
# TODO add validation for params (maybe during lowering via custom lower?)
50-
n_qubits: int = info.attribute()
51-
params: ir.SSAValue = info.argument(types.Tuple[types.Vararg(types.Float)])
52-
result: ir.ResultValue = info.result(OpType)
43+
class SingleQubitPauliChannel(NoiseChannel):
44+
params: ir.SSAValue = info.argument(ilist.IListType[types.Float, types.Literal(3)])
45+
46+
47+
@statement(dialect=dialect)
48+
class TwoQubitPauliChannel(NoiseChannel):
49+
params: ir.SSAValue = info.argument(ilist.IListType[types.Float, types.Literal(15)])
5350

5451

5552
@statement(dialect=dialect)
5653
class QubitLoss(NoiseChannel):
5754
# NOTE: qubit loss error (not supported by Stim)
5855
p: ir.SSAValue = info.argument(types.Float)
56+
57+
58+
@statement(dialect=dialect)
59+
class StochasticUnitaryChannel(ir.Statement):
60+
operators: ir.SSAValue = info.argument(ilist.IListType[OpType, NumOperators])
61+
probabilities: ir.SSAValue = info.argument(
62+
ilist.IListType[types.Float, NumOperators]
63+
)
5964
result: ir.ResultValue = info.result(OpType)

src/bloqade/squin/op/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@ def __rmul__(self, other: complex) -> "Op":
2222

2323

2424
OpType = types.PyClass(Op)
25+
26+
NumOperators = types.TypeVar("NumOperators")
File renamed without changes.

0 commit comments

Comments
 (0)