diff --git a/src/bloqade/noise/native/stmts.py b/src/bloqade/noise/native/stmts.py index d969e71c..1a840b3e 100644 --- a/src/bloqade/noise/native/stmts.py +++ b/src/bloqade/noise/native/stmts.py @@ -1,3 +1,5 @@ +from typing import Tuple + from kirin import ir, types, lowering from kirin.decl import info, statement from kirin.dialects import ilist @@ -7,25 +9,44 @@ from ._dialect import dialect -@statement(dialect=dialect) -class PauliChannel(ir.Statement): - +@statement +class NativeNoiseStmt(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) + @property + def probabilities(self) -> Tuple[Tuple[float, ...], ...]: + raise NotImplementedError(f"Override the method in {type(self).__name__}") + + def check(self): + for probs in self.probabilities: + self.check_probability(sum(probs)) + for p in probs: + self.check_probability(p) + + def check_probability(self, p: float): + if not 0 <= p <= 1: + raise ValueError( + f"Invalid noise probability encountered in {type(self).__name__}: {p}" + ) + + +@statement(dialect=dialect) +class PauliChannel(NativeNoiseStmt): px: float = info.attribute(types.Float) py: float = info.attribute(types.Float) pz: float = info.attribute(types.Float) qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType]) + @property + def probabilities(self) -> Tuple[Tuple[float, ...], ...]: + return ((self.px, self.py, self.pz),) + NumQubits = types.TypeVar("NumQubits") @statement(dialect=dialect) -class CZPauliChannel(ir.Statement): - - traits = frozenset({lowering.FromPythonCall()}) - +class CZPauliChannel(NativeNoiseStmt): paired: bool = info.attribute(types.Bool) px_ctrl: float = info.attribute(types.Float) py_ctrl: float = info.attribute(types.Float) @@ -36,11 +57,19 @@ class CZPauliChannel(ir.Statement): ctrls: ir.SSAValue = info.argument(ilist.IListType[QubitType, NumQubits]) qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType, NumQubits]) + @property + def probabilities(self) -> Tuple[Tuple[float, ...], ...]: + return ( + (self.px_ctrl, self.py_ctrl, self.pz_ctrl), + (self.px_qarg, self.py_qarg, self.pz_qarg), + ) -@statement(dialect=dialect) -class AtomLossChannel(ir.Statement): - - traits = frozenset({lowering.FromPythonCall()}) +@statement(dialect=dialect) +class AtomLossChannel(NativeNoiseStmt): prob: float = info.attribute(types.Float) qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType]) + + @property + def probabilities(self) -> Tuple[Tuple[float, ...], ...]: + return ((self.prob,),) diff --git a/test/pyqrack/runtime/noise/native/test_pauli.py b/test/pyqrack/runtime/noise/native/test_pauli.py index 34ed839f..d67e24d1 100644 --- a/test/pyqrack/runtime/noise/native/test_pauli.py +++ b/test/pyqrack/runtime/noise/native/test_pauli.py @@ -1,5 +1,6 @@ from unittest.mock import Mock, call +import pytest from kirin import ir from bloqade import qasm2 @@ -41,6 +42,23 @@ def test_atom_loss(): sim_reg.assert_has_calls([call.y(0)]) +@pytest.mark.xfail +def test_pauli_probs_check(): + @simulation + def test_atom_loss(): + q = qasm2.qreg(2) + native.pauli_channel( + [q[0]], + px=0.1, + py=0.4, + pz=1.3, + ) + return q + + with pytest.raises(ir.ValidationError): + test_atom_loss.verify() + + def test_cz_pauli_channel_false(): @simulation def test_atom_loss(): @@ -122,9 +140,3 @@ def test_atom_loss(): sim_reg = run_mock(test_atom_loss, rng_state) sim_reg.assert_has_calls([call.y(0), call.x(1), call.mcz([0], 1)]) - - -if __name__ == "__main__": - test_pauli_channel() - test_cz_pauli_channel_false() - test_cz_pauli_channel_true()