Skip to content

Commit e973ebe

Browse files
authored
Refactor & simplify noise dialect (#512)
1 parent 037a7e5 commit e973ebe

32 files changed

+597
-871
lines changed

src/bloqade/cirq_utils/emit/noise.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,6 @@
1616
@noise.dialect.register(key="emit.cirq")
1717
class EmitCirqNoiseMethods(MethodTable):
1818

19-
@impl(noise.stmts.PauliError)
20-
def pauli_error(
21-
self, interp: EmitCirq, frame: EmitCirqFrame, stmt: noise.stmts.PauliError
22-
):
23-
op = frame.get(stmt.basis)
24-
p = frame.get(stmt.p)
25-
error_probabilities = {self._op_to_key(op): p}
26-
gate = cirq.asymmetric_depolarize(error_probabilities=error_probabilities)
27-
return (BasicOpRuntime(gate=gate),)
28-
2919
@impl(noise.stmts.Depolarize)
3020
def depolarize(
3121
self, interp: EmitCirq, frame: EmitCirqFrame, stmt: noise.stmts.Depolarize
@@ -72,20 +62,6 @@ def two_qubit_pauli_channel(
7262
gate = cirq.asymmetric_depolarize(error_probabilities=error_probabilities)
7363
return (BasicOpRuntime(gate),)
7464

75-
@impl(noise.stmts.StochasticUnitaryChannel)
76-
def stochastic_unitary_channel(
77-
self,
78-
emit: EmitCirq,
79-
frame: EmitCirqFrame,
80-
stmt: noise.stmts.StochasticUnitaryChannel,
81-
):
82-
ops = frame.get(stmt.operators)
83-
ps = frame.get(stmt.probabilities)
84-
85-
error_probabilities = {self._op_to_key(op_): p for op_, p in zip(ops, ps)}
86-
cirq_op = cirq.asymmetric_depolarize(error_probabilities=error_probabilities)
87-
return (BasicOpRuntime(cirq_op),)
88-
8965
@staticmethod
9066
def _op_to_key(operator: OperatorRuntimeABC) -> str:
9167
match operator:
Lines changed: 74 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,57 @@
1-
import random
2-
import typing
3-
from functools import cached_property
4-
from dataclasses import dataclass
5-
61
from kirin import interp
7-
from kirin.dialects import ilist
82

9-
from bloqade.pyqrack import QubitState, PyQrackQubit, PyQrackInterpreter
3+
from bloqade.pyqrack import PyQrackQubit, PyQrackInterpreter
104
from bloqade.squin.noise.stmts import (
115
QubitLoss,
126
Depolarize,
13-
PauliError,
147
Depolarize2,
158
TwoQubitPauliChannel,
169
SingleQubitPauliChannel,
17-
StochasticUnitaryChannel,
1810
)
1911
from bloqade.squin.noise._dialect import dialect as squin_noise_dialect
2012

21-
from ..runtime import KronRuntime, IdentityRuntime, OperatorRuntime, OperatorRuntimeABC
22-
23-
24-
@dataclass(frozen=True)
25-
class StochasticUnitaryChannelRuntime(OperatorRuntimeABC):
26-
operators: (
27-
ilist.IList[OperatorRuntimeABC, typing.Any] | tuple[OperatorRuntimeABC, ...]
28-
)
29-
probabilities: ilist.IList[float, typing.Any] | tuple[float, ...]
30-
31-
@property
32-
def n_sites(self) -> int:
33-
n = self.operators[0].n_sites
34-
for op in self.operators[1:]:
35-
assert (
36-
op.n_sites == n
37-
), "Encountered a stochastic unitary channel with operators of different size!"
38-
return n
39-
40-
def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
41-
# NOTE: probabilities don't necessarily sum to 1; could be no noise event should occur
42-
p_no_op = 1 - sum(self.probabilities)
43-
if random.uniform(0.0, 1.0) < p_no_op:
44-
return
45-
46-
selected_ops = random.choices(self.operators, weights=self.probabilities)
47-
for op in selected_ops:
48-
op.apply(*qubits, adjoint=adjoint)
49-
50-
51-
@dataclass(frozen=True)
52-
class QubitLossRuntime(OperatorRuntimeABC):
53-
p: float
54-
55-
@property
56-
def n_sites(self) -> int:
57-
return 1
58-
59-
def apply(self, qubit: PyQrackQubit, adjoint: bool = False) -> None:
60-
if random.uniform(0.0, 1.0) <= self.p:
61-
qubit.state = QubitState.Lost
62-
6313

6414
@squin_noise_dialect.register(key="pyqrack")
6515
class PyQrackMethods(interp.MethodTable):
66-
@interp.impl(PauliError)
67-
def pauli_error(
68-
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: PauliError
69-
):
70-
op = frame.get(stmt.basis)
71-
p = frame.get(stmt.p)
72-
return (StochasticUnitaryChannelRuntime((op,), (p,)),)
16+
17+
single_pauli_choices = ("i", "x", "y", "z")
18+
two_pauli_choices = (
19+
"ii",
20+
"ix",
21+
"iy",
22+
"iz",
23+
"xi",
24+
"xx",
25+
"xy",
26+
"xz",
27+
"yi",
28+
"yx",
29+
"yy",
30+
"yz",
31+
"zi",
32+
"zx",
33+
"zy",
34+
"zz",
35+
)
7336

7437
@interp.impl(Depolarize)
7538
def depolarize(
7639
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: Depolarize
7740
):
7841
p = frame.get(stmt.p)
79-
ps = (p / 3.0,) * 3
80-
ops = self.single_qubit_paulis
81-
return (StochasticUnitaryChannelRuntime(ops, ps),)
42+
ps = [p / 3.0] * 3
43+
qubits = frame.get(stmt.qubits)
44+
self.apply_single_qubit_pauli_error(interp, ps, qubits)
8245

8346
@interp.impl(Depolarize2)
8447
def depolarize2(
8548
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: Depolarize2
8649
):
8750
p = frame.get(stmt.p)
88-
ps = (p / 15.0,) * 15
89-
ops = self.two_qubit_paulis
90-
return (StochasticUnitaryChannelRuntime(ops, ps),)
51+
ps = [p / 15.0] * 15
52+
controls = frame.get(stmt.controls)
53+
targets = frame.get(stmt.targets)
54+
self.apply_two_qubit_pauli_error(interp, ps, controls, targets)
9155

9256
@interp.impl(SingleQubitPauliChannel)
9357
def single_qubit_pauli_channel(
@@ -96,9 +60,11 @@ def single_qubit_pauli_channel(
9660
frame: interp.Frame,
9761
stmt: SingleQubitPauliChannel,
9862
):
99-
ps = frame.get(stmt.params)
100-
ops = self.single_qubit_paulis
101-
return (StochasticUnitaryChannelRuntime(ops, ps),)
63+
px = frame.get(stmt.px)
64+
py = frame.get(stmt.py)
65+
pz = frame.get(stmt.pz)
66+
qubits = frame.get(stmt.qubits)
67+
self.apply_single_qubit_pauli_error(interp, [px, py, pz], qubits)
10268

10369
@interp.impl(TwoQubitPauliChannel)
10470
def two_qubit_pauli_channel(
@@ -107,43 +73,54 @@ def two_qubit_pauli_channel(
10773
frame: interp.Frame,
10874
stmt: TwoQubitPauliChannel,
10975
):
110-
ps = frame.get(stmt.params)
111-
ops = self.two_qubit_paulis
112-
return (StochasticUnitaryChannelRuntime(ops, ps),)
113-
114-
@interp.impl(StochasticUnitaryChannel)
115-
def stochastic_unitary_channel(
116-
self,
117-
interp: PyQrackInterpreter,
118-
frame: interp.Frame,
119-
stmt: StochasticUnitaryChannel,
120-
):
121-
operators = frame.get(stmt.operators)
122-
probabilities = frame.get(stmt.probabilities)
123-
124-
return (StochasticUnitaryChannelRuntime(operators, probabilities),)
76+
ps = frame.get(stmt.probabilities)
77+
controls = frame.get(stmt.controls)
78+
targets = frame.get(stmt.targets)
79+
self.apply_two_qubit_pauli_error(interp, ps, controls, targets)
12580

12681
@interp.impl(QubitLoss)
12782
def qubit_loss(
12883
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: QubitLoss
12984
):
13085
p = frame.get(stmt.p)
131-
return (QubitLossRuntime(p),)
86+
qubits: list[PyQrackQubit] = frame.get(stmt.qubits)
87+
for qbit in qubits:
88+
if interp.rng_state.uniform(0.0, 1.0) <= p:
89+
qbit.drop()
13290

133-
@cached_property
134-
def single_qubit_paulis(self):
135-
return (OperatorRuntime("x"), OperatorRuntime("y"), OperatorRuntime("z"))
91+
def apply_single_qubit_pauli_error(
92+
self,
93+
interp: PyQrackInterpreter,
94+
ps: list[float],
95+
qubits: list[PyQrackQubit],
96+
):
97+
pi = 1 - sum(ps)
98+
probs = [pi] + ps
13699

137-
@cached_property
138-
def two_qubit_paulis(self):
139-
paulis = (IdentityRuntime(sites=1), *self.single_qubit_paulis)
140-
ops: list[KronRuntime] = []
141-
for idx1, pauli1 in enumerate(paulis):
142-
for idx2, pauli2 in enumerate(paulis):
143-
if idx1 == idx2 == 0:
144-
# NOTE: 'II'
145-
continue
100+
assert all(0 <= x <= 1 for x in probs), "Invalid Pauli error probabilities"
146101

147-
ops.append(KronRuntime(pauli1, pauli2))
102+
for qbit in qubits:
103+
which = interp.rng_state.choice(self.single_pauli_choices, p=probs)
104+
self.apply_pauli_error(which, qbit)
105+
106+
def apply_two_qubit_pauli_error(
107+
self,
108+
interp: PyQrackInterpreter,
109+
ps: list[float],
110+
controls: list[PyQrackQubit],
111+
targets: list[PyQrackQubit],
112+
):
113+
pii = 1 - sum(ps)
114+
probs = [pii] + ps
115+
assert all(0 <= x <= 1 for x in probs), "Invalid Pauli error probabilities"
116+
117+
for control, target in zip(controls, targets):
118+
which = interp.rng_state.choice(self.two_pauli_choices, p=probs)
119+
self.apply_pauli_error(which[0], control)
120+
self.apply_pauli_error(which[1], target)
121+
122+
def apply_pauli_error(self, which: str, qbit: PyQrackQubit):
123+
if not qbit.is_active() or which == "i":
124+
return
148125

149-
return tuple(ops)
126+
getattr(qbit.sim_reg, which)(qbit.addr)

src/bloqade/squin/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,21 @@
2929
sqrt_x as sqrt_x,
3030
sqrt_y as sqrt_y,
3131
sqrt_z as sqrt_z,
32+
bit_flip as bit_flip,
33+
depolarize as depolarize,
34+
qubit_loss as qubit_loss,
3235
sqrt_x_adj as sqrt_x_adj,
3336
sqrt_y_adj as sqrt_y_adj,
3437
sqrt_z_adj as sqrt_z_adj,
38+
depolarize2 as depolarize2,
39+
two_qubit_pauli_channel as two_qubit_pauli_channel,
40+
single_qubit_pauli_channel as single_qubit_pauli_channel,
3541
)
3642

3743
# NOTE: it's important to keep these imports here since they import squin.kernel
3844
# we skip isort here
3945
from . import parallel as parallel # isort: skip
4046
from .stdlib import ( # isort: skip
4147
gate as gate,
42-
channel as channel,
4348
broadcast as broadcast,
4449
)

src/bloqade/squin/analysis/nsites/impls.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -124,34 +124,6 @@ def two_qubit_noise(
124124
):
125125
return (NumberSites(sites=2),)
126126

127-
@interp.impl(noise.stmts.PauliError)
128-
def pauli_error(
129-
self, interp: NSitesAnalysis, frame: interp.Frame, stmt: noise.stmts.PauliError
130-
):
131-
pauli_ops_sites = frame.get(stmt.basis)
132-
return (pauli_ops_sites,)
133-
134-
@interp.impl(noise.stmts.StochasticUnitaryChannel)
135-
def stochastic_unitary_noise(
136-
self,
137-
interp: NSitesAnalysis,
138-
frame: interp.Frame,
139-
stmt: noise.stmts.StochasticUnitaryChannel,
140-
):
141-
ops = frame.get(stmt.operators)
142-
143-
# StochasticUnitaryChannel always accepts an IList of Operators
144-
# but it's the number of sites of the individual operator themselves that should
145-
# represent the sites the channel acts on.
146-
if (
147-
isinstance(ops, tuple)
148-
and all(isinstance(op, NumberSites) for op in ops)
149-
and len(set([op.sites for op in ops])) == 1
150-
):
151-
return (ops[0],)
152-
153-
return (NoSites(),)
154-
155127

156128
@scf.dialect.register(key="op.nsites")
157129
class ScfSquinOp(ScfTypeInfer):
Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,2 @@
1-
from . import stmts as stmts
1+
from . import stmts as stmts, _interface as _interface
22
from ._dialect import dialect as dialect
3-
from ._wrapper import (
4-
depolarize as depolarize,
5-
qubit_loss as qubit_loss,
6-
depolarize2 as depolarize2,
7-
pauli_error as pauli_error,
8-
two_qubit_pauli_channel as two_qubit_pauli_channel,
9-
single_qubit_pauli_channel as single_qubit_pauli_channel,
10-
stochastic_unitary_channel as stochastic_unitary_channel,
11-
)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from kirin import ir
22

3-
dialect = ir.Dialect(name="squin.noise")
3+
dialect = ir.Dialect("squin.noise")
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from typing import Any, Literal, TypeVar
2+
3+
from kirin.dialects import ilist
4+
from kirin.lowering import wraps
5+
6+
from bloqade.types import Qubit
7+
8+
from . import stmts
9+
10+
11+
@wraps(stmts.Depolarize)
12+
def depolarize(p: float, qubits: ilist.IList[Qubit, Any]) -> None: ...
13+
14+
15+
N = TypeVar("N", bound=int)
16+
17+
18+
@wraps(stmts.Depolarize2)
19+
def depolarize2(
20+
p: float, controls: ilist.IList[Qubit, N], targets: ilist.IList[Qubit, N]
21+
) -> None: ...
22+
23+
24+
@wraps(stmts.SingleQubitPauliChannel)
25+
def single_qubit_pauli_channel(
26+
px: float, py: float, pz: float, qubits: ilist.IList[Qubit, Any]
27+
) -> None: ...
28+
29+
30+
@wraps(stmts.TwoQubitPauliChannel)
31+
def two_qubit_pauli_channel(
32+
probabilities: ilist.IList[float, Literal[15]],
33+
controls: ilist.IList[Qubit, N],
34+
targets: ilist.IList[Qubit, N],
35+
) -> None: ...
36+
37+
38+
@wraps(stmts.QubitLoss)
39+
def qubit_loss(p: float, qubits: ilist.IList[Qubit, Any]) -> None: ...

0 commit comments

Comments
 (0)