Skip to content

Commit d044d1b

Browse files
authored
Implement probability checks for native noise statements (#184)
* Implement probability checks for native noise statements * Add test (marked with xfail for now) * Share check method between classes and simplify test * Raise NotImplementedError in base class and remove test execution in main
1 parent 8130a36 commit d044d1b

File tree

2 files changed

+58
-17
lines changed

2 files changed

+58
-17
lines changed

src/bloqade/noise/native/stmts.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Tuple
2+
13
from kirin import ir, types, lowering
24
from kirin.decl import info, statement
35
from kirin.dialects import ilist
@@ -7,25 +9,44 @@
79
from ._dialect import dialect
810

911

10-
@statement(dialect=dialect)
11-
class PauliChannel(ir.Statement):
12-
12+
@statement
13+
class NativeNoiseStmt(ir.Statement):
1314
traits = frozenset({lowering.FromPythonCall()})
1415

16+
@property
17+
def probabilities(self) -> Tuple[Tuple[float, ...], ...]:
18+
raise NotImplementedError(f"Override the method in {type(self).__name__}")
19+
20+
def check(self):
21+
for probs in self.probabilities:
22+
self.check_probability(sum(probs))
23+
for p in probs:
24+
self.check_probability(p)
25+
26+
def check_probability(self, p: float):
27+
if not 0 <= p <= 1:
28+
raise ValueError(
29+
f"Invalid noise probability encountered in {type(self).__name__}: {p}"
30+
)
31+
32+
33+
@statement(dialect=dialect)
34+
class PauliChannel(NativeNoiseStmt):
1535
px: float = info.attribute(types.Float)
1636
py: float = info.attribute(types.Float)
1737
pz: float = info.attribute(types.Float)
1838
qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType])
1939

40+
@property
41+
def probabilities(self) -> Tuple[Tuple[float, ...], ...]:
42+
return ((self.px, self.py, self.pz),)
43+
2044

2145
NumQubits = types.TypeVar("NumQubits")
2246

2347

2448
@statement(dialect=dialect)
25-
class CZPauliChannel(ir.Statement):
26-
27-
traits = frozenset({lowering.FromPythonCall()})
28-
49+
class CZPauliChannel(NativeNoiseStmt):
2950
paired: bool = info.attribute(types.Bool)
3051
px_ctrl: float = info.attribute(types.Float)
3152
py_ctrl: float = info.attribute(types.Float)
@@ -36,11 +57,19 @@ class CZPauliChannel(ir.Statement):
3657
ctrls: ir.SSAValue = info.argument(ilist.IListType[QubitType, NumQubits])
3758
qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType, NumQubits])
3859

60+
@property
61+
def probabilities(self) -> Tuple[Tuple[float, ...], ...]:
62+
return (
63+
(self.px_ctrl, self.py_ctrl, self.pz_ctrl),
64+
(self.px_qarg, self.py_qarg, self.pz_qarg),
65+
)
3966

40-
@statement(dialect=dialect)
41-
class AtomLossChannel(ir.Statement):
42-
43-
traits = frozenset({lowering.FromPythonCall()})
4467

68+
@statement(dialect=dialect)
69+
class AtomLossChannel(NativeNoiseStmt):
4570
prob: float = info.attribute(types.Float)
4671
qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType])
72+
73+
@property
74+
def probabilities(self) -> Tuple[Tuple[float, ...], ...]:
75+
return ((self.prob,),)

test/pyqrack/runtime/noise/native/test_pauli.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from unittest.mock import Mock, call
22

3+
import pytest
34
from kirin import ir
45

56
from bloqade import qasm2
@@ -41,6 +42,23 @@ def test_atom_loss():
4142
sim_reg.assert_has_calls([call.y(0)])
4243

4344

45+
@pytest.mark.xfail
46+
def test_pauli_probs_check():
47+
@simulation
48+
def test_atom_loss():
49+
q = qasm2.qreg(2)
50+
native.pauli_channel(
51+
[q[0]],
52+
px=0.1,
53+
py=0.4,
54+
pz=1.3,
55+
)
56+
return q
57+
58+
with pytest.raises(ir.ValidationError):
59+
test_atom_loss.verify()
60+
61+
4462
def test_cz_pauli_channel_false():
4563
@simulation
4664
def test_atom_loss():
@@ -122,9 +140,3 @@ def test_atom_loss():
122140
sim_reg = run_mock(test_atom_loss, rng_state)
123141

124142
sim_reg.assert_has_calls([call.y(0), call.x(1), call.mcz([0], 1)])
125-
126-
127-
if __name__ == "__main__":
128-
test_pauli_channel()
129-
test_cz_pauli_channel_false()
130-
test_cz_pauli_channel_true()

0 commit comments

Comments
 (0)