Skip to content

Commit f3cac3e

Browse files
committed
Share check method between classes and simplify test
1 parent 3284632 commit f3cac3e

File tree

3 files changed

+55
-132
lines changed

3 files changed

+55
-132
lines changed

src/bloqade/noise/native/stmts.py

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Tuple
2+
13
from kirin import ir, types, lowering
24
from kirin.decl import info, statement
35
from kirin.dialects import ilist
@@ -7,30 +9,43 @@
79
from ._dialect import dialect
810

911

10-
@statement(dialect=dialect)
11-
class PauliChannel(ir.Statement):
12-
12+
@statement
13+
class NativeNoiseStmt(ir.Statement):
1314
traits = frozenset({lowering.FromPythonCall()})
1415

16+
@property
17+
def probabilities(self) -> Tuple[Tuple[float, ...], ...]: ...
18+
19+
def check(self):
20+
for probs in self.probabilities:
21+
self.check_probability(sum(probs))
22+
for p in probs:
23+
self.check_probability(p)
24+
25+
def check_probability(self, p: float):
26+
if not 0 <= p <= 1:
27+
raise ValueError(
28+
f"Invalid noise probability encountered in {type(self).__name__}: {p}"
29+
)
30+
31+
32+
@statement(dialect=dialect)
33+
class PauliChannel(NativeNoiseStmt):
1534
px: float = info.attribute(types.Float)
1635
py: float = info.attribute(types.Float)
1736
pz: float = info.attribute(types.Float)
1837
qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType])
1938

20-
def check(self):
21-
probs = (self.px, self.py, self.pz)
22-
if not all(0 <= p <= 1 for p in probs) or not 0 <= sum(probs) <= 1:
23-
raise ValueError(f"Invalid Pauli error probabilities (px, py, pz): {probs}")
39+
@property
40+
def probabilities(self) -> Tuple[Tuple[float, ...], ...]:
41+
return ((self.px, self.py, self.pz),)
2442

2543

2644
NumQubits = types.TypeVar("NumQubits")
2745

2846

2947
@statement(dialect=dialect)
30-
class CZPauliChannel(ir.Statement):
31-
32-
traits = frozenset({lowering.FromPythonCall()})
33-
48+
class CZPauliChannel(NativeNoiseStmt):
3449
paired: bool = info.attribute(types.Bool)
3550
px_ctrl: float = info.attribute(types.Float)
3651
py_ctrl: float = info.attribute(types.Float)
@@ -41,32 +56,19 @@ class CZPauliChannel(ir.Statement):
4156
ctrls: ir.SSAValue = info.argument(ilist.IListType[QubitType, NumQubits])
4257
qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType, NumQubits])
4358

44-
def check(self):
45-
probs_ctrl = (self.px_ctrl, self.py_ctrl, self.pz_ctrl)
46-
47-
def check_prob(p: float) -> bool:
48-
return 0 <= p <= 1
49-
50-
if not map(check_prob, probs_ctrl) or not check_prob(sum(probs_ctrl)):
51-
raise ValueError(
52-
f"Invalid control probabilities for CZ Pauli channel (px_ctrl, py_ctrl, pz_ctrl): {probs_ctrl}"
53-
)
54-
55-
probs_qarg = (self.px_qarg, self.py_qarg, self.pz_qarg)
56-
if not map(check_prob, probs_qarg) or not check_prob(sum(probs_qarg)):
57-
raise ValueError(
58-
f"Invalid probabilities for CZ Pauli channel (px_qarg, py_qarg, pz_qarg): {probs_qarg}"
59-
)
59+
@property
60+
def probabilities(self) -> Tuple[Tuple[float, ...], ...]:
61+
return (
62+
(self.px_ctrl, self.py_ctrl, self.pz_ctrl),
63+
(self.px_qarg, self.py_qarg, self.pz_qarg),
64+
)
6065

6166

6267
@statement(dialect=dialect)
63-
class AtomLossChannel(ir.Statement):
64-
65-
traits = frozenset({lowering.FromPythonCall()})
66-
68+
class AtomLossChannel(NativeNoiseStmt):
6769
prob: float = info.attribute(types.Float)
6870
qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType])
6971

70-
def check(self):
71-
if not 0 <= self.prob <= 1:
72-
raise ValueError(f"Invalid atom loss probability {self.prob}")
72+
@property
73+
def probabilities(self) -> Tuple[Tuple[float, ...], ...]:
74+
return ((self.prob,),)
Lines changed: 0 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
from typing import Literal
2-
import textwrap
32
from unittest.mock import Mock
43

5-
import pytest
64
from kirin import ir
75
from kirin.dialects import ilist
86

97
from bloqade import qasm2
108
from bloqade.noise import native
119
from bloqade.pyqrack import PyQrackQubit, PyQrackInterpreter, reg
1210
from bloqade.pyqrack.base import MockMemory
13-
from bloqade.qasm2.passes import QASM2Py, NoisePass, QASM2Fold
14-
from bloqade.qasm2.parse.lowering import QASM2
1511

1612
simulation = qasm2.extended.add(native)
1713

@@ -47,97 +43,3 @@ def test_atom_loss(c: qasm2.CReg):
4743
assert result[0].state is reg.QubitState.Lost
4844
assert result[1].state is reg.QubitState.Active
4945
assert input[0] is reg.Measurement.One
50-
51-
52-
@pytest.mark.xfail
53-
def test_noise_probs():
54-
test_qasm = textwrap.dedent(
55-
"""
56-
OPENQASM 2.0;
57-
include "qelib1.inc";
58-
59-
// Qubits: [q_0, q_1, q_2, q_3, q_4, q_5]
60-
qreg q[6];
61-
62-
63-
u3(pi*0.9999896015,pi*1.8867094803,pi*0.1132905197) q[2];
64-
u3(pi*1.499959526,pi*1.2634437582,pi*0.7365562418) q[3];
65-
u3(pi*1.4998447568,pi*1.8205928898,pi*0.1794071102) q[4];
66-
u3(pi*1.4998052589,pi*1.5780611154,pi*0.4219388846) q[5];
67-
u3(pi*0.4920440401,pi*1.287644074,pi*0.712355926) q[0];
68-
u3(pi*1.0012473155,pi*1.3019213156,pi*0.6980786844) q[1];
69-
cz q[1],q[2];
70-
cz q[1],q[2];
71-
cz q[2],q[3];
72-
cz q[4],q[5];
73-
cz q[2],q[3];
74-
cz q[4],q[5];
75-
cz q[2],q[3];
76-
cz q[4],q[5];
77-
cz q[2],q[3];
78-
cz q[4],q[5];
79-
cz q[2],q[3];
80-
cz q[4],q[5];
81-
u3(pi*1.0,pi*1.5687764466,pi*0.4312235534) q[2];
82-
u3(pi*0.5,0,pi*1.7365086077) q[3];
83-
u3(pi*0.5,pi*1.0,pi*0.6112880576) q[4];
84-
u3(pi*0.1388474164,pi*1.7687898606,pi*1.2425564668) q[5];
85-
"""
86-
)
87-
88-
entry = QASM2(qasm2.main.add(qasm2.inline_)).loads(test_qasm, "entry", returns="q")
89-
QASM2Py(entry.dialects)(entry)
90-
entry = entry.similar(qasm2.extended.add(native))
91-
QASM2Fold(entry.dialects).fixpoint(entry)
92-
93-
# Noise parameters
94-
gate_noise_value = 1e-3
95-
move_noise_value = 0.5
96-
97-
gate_noise_params = native.GateNoiseParams(
98-
local_px=gate_noise_value,
99-
local_py=gate_noise_value,
100-
local_pz=gate_noise_value,
101-
local_loss_prob=gate_noise_value,
102-
#
103-
global_px=gate_noise_value,
104-
global_py=gate_noise_value,
105-
global_pz=gate_noise_value,
106-
global_loss_prob=gate_noise_value,
107-
#
108-
cz_paired_gate_px=gate_noise_value,
109-
cz_paired_gate_py=gate_noise_value,
110-
cz_paired_gate_pz=gate_noise_value,
111-
cz_gate_loss_prob=gate_noise_value,
112-
#
113-
cz_unpaired_gate_px=gate_noise_value,
114-
cz_unpaired_gate_py=gate_noise_value,
115-
cz_unpaired_gate_pz=gate_noise_value,
116-
cz_unpaired_loss_prob=gate_noise_value,
117-
)
118-
119-
move_noise_params = native.model.MoveNoiseParams(
120-
idle_px_rate=move_noise_value,
121-
idle_py_rate=move_noise_value,
122-
idle_pz_rate=move_noise_value,
123-
idle_loss_rate=move_noise_value,
124-
move_px_rate=move_noise_value,
125-
move_py_rate=move_noise_value,
126-
move_pz_rate=move_noise_value,
127-
move_loss_rate=move_noise_value,
128-
#
129-
pick_px=move_noise_value,
130-
pick_py=move_noise_value,
131-
pick_pz=move_noise_value,
132-
pick_loss_prob=move_noise_value,
133-
#
134-
move_speed=5e-1, # default 5e-1
135-
storage_spacing=4.0, # default 4.0
136-
)
137-
138-
with pytest.raises(ir.ValidationError):
139-
NoisePass(
140-
entry.dialects,
141-
gate_noise_params=gate_noise_params,
142-
noise_model=native.TwoRowZoneModel(params=move_noise_params),
143-
)(entry)

test/pyqrack/runtime/noise/native/test_pauli.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from unittest.mock import Mock, call
22

3+
import pytest
34
from kirin import ir
45

56
from bloqade import qasm2
@@ -41,6 +42,23 @@ def test_atom_loss():
4142
sim_reg.assert_has_calls([call.y(0)])
4243

4344

45+
@pytest.mark.xfail
46+
def test_pauli_probs_check():
47+
@simulation
48+
def test_atom_loss():
49+
q = qasm2.qreg(2)
50+
native.pauli_channel(
51+
[q[0]],
52+
px=0.1,
53+
py=0.4,
54+
pz=1.3,
55+
)
56+
return q
57+
58+
with pytest.raises(ir.ValidationError):
59+
test_atom_loss.verify()
60+
61+
4462
def test_cz_pauli_channel_false():
4563
@simulation
4664
def test_atom_loss():
@@ -126,5 +144,6 @@ def test_atom_loss():
126144

127145
if __name__ == "__main__":
128146
test_pauli_channel()
147+
test_pauli_probs_check()
129148
test_cz_pauli_channel_false()
130149
test_cz_pauli_channel_true()

0 commit comments

Comments
 (0)