diff --git a/src/bloqade/pyqrack/__init__.py b/src/bloqade/pyqrack/__init__.py index 27fd3e54..3664fda6 100644 --- a/src/bloqade/pyqrack/__init__.py +++ b/src/bloqade/pyqrack/__init__.py @@ -3,6 +3,7 @@ CRegister as CRegister, QubitState as QubitState, Measurement as Measurement, + PyQrackWire as PyQrackWire, PyQrackQubit as PyQrackQubit, ) from .base import ( @@ -14,4 +15,5 @@ # NOTE: The following import is for registering the method tables from .noise import native as native from .qasm2 import uop as uop, core as core, glob as glob, parallel as parallel +from .squin import op as op, qubit as qubit from .target import PyQrack as PyQrack diff --git a/src/bloqade/pyqrack/reg.py b/src/bloqade/pyqrack/reg.py index 7f498776..644f3859 100644 --- a/src/bloqade/pyqrack/reg.py +++ b/src/bloqade/pyqrack/reg.py @@ -70,3 +70,8 @@ def is_active(self) -> bool: def drop(self): """Drop the qubit in-place.""" self.state = QubitState.Lost + + +@dataclass +class PyQrackWire: + qubit: PyQrackQubit diff --git a/src/bloqade/pyqrack/squin/__init__.py b/src/bloqade/pyqrack/squin/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bloqade/pyqrack/squin/op.py b/src/bloqade/pyqrack/squin/op.py new file mode 100644 index 00000000..738ec9b4 --- /dev/null +++ b/src/bloqade/pyqrack/squin/op.py @@ -0,0 +1,154 @@ +from kirin import interp + +from bloqade.squin import op +from bloqade.pyqrack.base import PyQrackInterpreter + +from .runtime import ( + SnRuntime, + SpRuntime, + U3Runtime, + RotRuntime, + KronRuntime, + MultRuntime, + ScaleRuntime, + AdjointRuntime, + ControlRuntime, + PhaseOpRuntime, + IdentityRuntime, + OperatorRuntime, + ProjectorRuntime, + OperatorRuntimeABC, + PauliStringRuntime, +) + + +@op.dialect.register(key="pyqrack") +class PyQrackMethods(interp.MethodTable): + + @interp.impl(op.stmts.Kron) + def kron( + self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Kron + ) -> tuple[OperatorRuntimeABC]: + lhs = frame.get(stmt.lhs) + rhs = frame.get(stmt.rhs) + return (KronRuntime(lhs, rhs),) + + @interp.impl(op.stmts.Mult) + def mult( + self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Mult + ) -> tuple[OperatorRuntimeABC]: + lhs = frame.get(stmt.lhs) + rhs = frame.get(stmt.rhs) + return (MultRuntime(lhs, rhs),) + + @interp.impl(op.stmts.Adjoint) + def adjoint( + self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Adjoint + ) -> tuple[OperatorRuntimeABC]: + op = frame.get(stmt.op) + return (AdjointRuntime(op),) + + @interp.impl(op.stmts.Scale) + def scale( + self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Scale + ) -> tuple[OperatorRuntimeABC]: + op = frame.get(stmt.op) + factor = frame.get(stmt.factor) + return (ScaleRuntime(op, factor),) + + @interp.impl(op.stmts.Control) + def control( + self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Control + ) -> tuple[OperatorRuntimeABC]: + op = frame.get(stmt.op) + n_controls = stmt.n_controls + rt = ControlRuntime( + op=op, + n_controls=n_controls, + ) + return (rt,) + + @interp.impl(op.stmts.Rot) + def rot( + self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Rot + ) -> tuple[OperatorRuntimeABC]: + axis = frame.get(stmt.axis) + angle = frame.get(stmt.angle) + return (RotRuntime(axis, angle),) + + @interp.impl(op.stmts.Identity) + def identity( + self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Identity + ) -> tuple[OperatorRuntimeABC]: + return (IdentityRuntime(sites=stmt.sites),) + + @interp.impl(op.stmts.PhaseOp) + @interp.impl(op.stmts.ShiftOp) + def phaseop( + self, + interp: PyQrackInterpreter, + frame: interp.Frame, + stmt: op.stmts.PhaseOp | op.stmts.ShiftOp, + ) -> tuple[OperatorRuntimeABC]: + theta = frame.get(stmt.theta) + global_ = isinstance(stmt, op.stmts.PhaseOp) + return (PhaseOpRuntime(theta, global_=global_),) + + @interp.impl(op.stmts.X) + @interp.impl(op.stmts.Y) + @interp.impl(op.stmts.Z) + @interp.impl(op.stmts.H) + @interp.impl(op.stmts.S) + @interp.impl(op.stmts.T) + def operator( + self, + interp: PyQrackInterpreter, + frame: interp.Frame, + stmt: ( + op.stmts.X | op.stmts.Y | op.stmts.Z | op.stmts.H | op.stmts.S | op.stmts.T + ), + ) -> tuple[OperatorRuntimeABC]: + return (OperatorRuntime(method_name=stmt.name.lower()),) + + @interp.impl(op.stmts.P0) + @interp.impl(op.stmts.P1) + def projector( + self, + interp: PyQrackInterpreter, + frame: interp.Frame, + stmt: op.stmts.P0 | op.stmts.P1, + ) -> tuple[OperatorRuntimeABC]: + state = isinstance(stmt, op.stmts.P1) + return (ProjectorRuntime(to_state=state),) + + @interp.impl(op.stmts.Sp) + def sp( + self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Sp + ) -> tuple[OperatorRuntimeABC]: + return (SpRuntime(),) + + @interp.impl(op.stmts.Sn) + def sn( + self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Sn + ) -> tuple[OperatorRuntimeABC]: + return (SnRuntime(),) + + @interp.impl(op.stmts.U3) + def u3( + self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.U3 + ) -> tuple[OperatorRuntimeABC]: + theta = frame.get(stmt.theta) + phi = frame.get(stmt.phi) + lam = frame.get(stmt.lam) + return (U3Runtime(theta, phi, lam),) + + @interp.impl(op.stmts.PauliString) + def clifford_string( + self, + interp: PyQrackInterpreter, + frame: interp.Frame, + stmt: op.stmts.PauliString, + ) -> tuple[OperatorRuntimeABC]: + string = stmt.string + ops = [OperatorRuntime(method_name=name.lower()) for name in stmt.string] + return (PauliStringRuntime(string, ops),) diff --git a/src/bloqade/pyqrack/squin/qubit.py b/src/bloqade/pyqrack/squin/qubit.py new file mode 100644 index 00000000..a0fe0a89 --- /dev/null +++ b/src/bloqade/pyqrack/squin/qubit.py @@ -0,0 +1,85 @@ +from typing import Any + +from kirin import interp +from kirin.dialects import ilist + +from bloqade.squin import qubit +from bloqade.pyqrack.reg import QubitState, PyQrackQubit +from bloqade.pyqrack.base import PyQrackInterpreter + +from .runtime import OperatorRuntimeABC + + +@qubit.dialect.register(key="pyqrack") +class PyQrackMethods(interp.MethodTable): + @interp.impl(qubit.New) + def new(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.New): + n_qubits: int = frame.get(stmt.n_qubits) + qreg = ilist.IList( + [ + PyQrackQubit(i, interp.memory.sim_reg, QubitState.Active) + for i in interp.memory.allocate(n_qubits=n_qubits) + ] + ) + return (qreg,) + + @interp.impl(qubit.Apply) + def apply(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Apply): + qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits) + operator: OperatorRuntimeABC = frame.get(stmt.operator) + operator.apply(*qubits) + + @interp.impl(qubit.Broadcast) + def broadcast( + self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Broadcast + ): + operator: OperatorRuntimeABC = frame.get(stmt.operator) + qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits) + operator.broadcast_apply(qubits) + + def _measure_qubit(self, qbit: PyQrackQubit): + if qbit.is_active(): + return bool(qbit.sim_reg.m(qbit.addr)) + + @interp.impl(qubit.MeasureQubitList) + def measure_qubit_list( + self, + interp: PyQrackInterpreter, + frame: interp.Frame, + stmt: qubit.MeasureQubitList, + ): + qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits) + result = ilist.IList([self._measure_qubit(qbit) for qbit in qubits]) + return (result,) + + @interp.impl(qubit.MeasureQubit) + def measure_qubit( + self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasureQubit + ): + qbit: PyQrackQubit = frame.get(stmt.qubit) + result = self._measure_qubit(qbit) + return (result,) + + @interp.impl(qubit.MeasureAndReset) + def measure_and_reset( + self, + interp: PyQrackInterpreter, + frame: interp.Frame, + stmt: qubit.MeasureAndReset, + ): + qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits) + result = [] + for qbit in qubits: + if qbit.is_active(): + result.append(qbit.sim_reg.m(qbit.addr)) + else: + result.append(None) + qbit.sim_reg.force_m(qbit.addr, 0) + + return (ilist.IList(result),) + + @interp.impl(qubit.Reset) + def reset(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Reset): + qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits) + for qbit in qubits: + qbit.sim_reg.force_m(qbit.addr, 0) diff --git a/src/bloqade/pyqrack/squin/runtime.py b/src/bloqade/pyqrack/squin/runtime.py new file mode 100644 index 00000000..d14d0c55 --- /dev/null +++ b/src/bloqade/pyqrack/squin/runtime.py @@ -0,0 +1,515 @@ +from typing import Any +from dataclasses import field, dataclass + +import numpy as np +from kirin.dialects import ilist + +from pyqrack.pauli import Pauli +from bloqade.pyqrack import PyQrackQubit + + +@dataclass(frozen=True) +class OperatorRuntimeABC: + """The number of sites the operator applies to (including controls)""" + + @property + def n_sites(self) -> int: ... + + def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None: + raise NotImplementedError( + "Operator runtime base class should not be called directly, override the method" + ) + + def control_apply( + self, + controls: tuple[PyQrackQubit, ...], + targets: tuple[PyQrackQubit, ...], + adjoint: bool = False, + ) -> None: + raise RuntimeError(f"Can't apply controlled version of {self}") + + def broadcast_apply(self, qubits: ilist.IList[PyQrackQubit, Any], **kwargs) -> None: + n = self.n_sites + + if len(qubits) % n != 0: + raise RuntimeError( + f"Cannot broadcast operator {self} that applies to {n} over {len(qubits)} qubits." + ) + + for qubit_index in range(0, len(qubits), n): + targets = qubits[qubit_index : qubit_index + n] + self.apply(*targets, **kwargs) + + +@dataclass(frozen=True) +class OperatorRuntime(OperatorRuntimeABC): + method_name: str + + @property + def n_sites(self) -> int: + return 1 + + def get_method_name(self, adjoint: bool, control: bool) -> str: + method_name = "" + if control: + method_name += "mc" + + if adjoint and self.method_name in ("s", "t"): + method_name += "adj" + + return method_name + self.method_name + + def apply(self, qubit: PyQrackQubit, adjoint: bool = False) -> None: + if not qubit.is_active(): + return + method_name = self.get_method_name(adjoint=adjoint, control=False) + getattr(qubit.sim_reg, method_name)(qubit.addr) + + def control_apply( + self, + controls: tuple[PyQrackQubit, ...], + targets: tuple[PyQrackQubit], + adjoint: bool = False, + ) -> None: + target = targets[0] + if not target.is_active(): + return + + ctrls: list[int] = [] + for qbit in controls: + if not qbit.is_active(): + return + + ctrls.append(qbit.addr) + + method_name = self.get_method_name(adjoint=adjoint, control=True) + getattr(target.sim_reg, method_name)(ctrls, target.addr) + + +@dataclass(frozen=True) +class ControlRuntime(OperatorRuntimeABC): + op: OperatorRuntimeABC + n_controls: int + + @property + def n_sites(self) -> int: + return self.op.n_sites + self.n_controls + + def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None: + ctrls = qubits[: self.n_controls] + targets = qubits[self.n_controls :] + + if len(targets) != self.op.n_sites: + raise RuntimeError( + f"Cannot apply operator {self.op} to {len(targets)} qubits! It applies to {self.op.n_sites}, check your inputs!" + ) + + self.op.control_apply(controls=ctrls, targets=targets, adjoint=adjoint) + + +@dataclass(frozen=True) +class ProjectorRuntime(OperatorRuntimeABC): + to_state: bool + + @property + def n_sites(self) -> int: + return 1 + + def apply(self, qubit: PyQrackQubit, adjoint: bool = False) -> None: + if not qubit.is_active(): + return + qubit.sim_reg.force_m(qubit.addr, self.to_state) + + def control_apply( + self, + controls: tuple[PyQrackQubit, ...], + targets: tuple[PyQrackQubit], + adjoint: bool = False, + ) -> None: + target = targets[0] + if not target.is_active(): + return + + ctrls: list[int] = [] + for qbit in controls: + if not qbit.is_active(): + return + + m = [not self.to_state, 0, 0, self.to_state] + target.sim_reg.mcmtrx(ctrls, m, target.addr) + + +@dataclass(frozen=True) +class IdentityRuntime(OperatorRuntimeABC): + # TODO: do we even need sites? The apply never does anything + sites: int + + @property + def n_sites(self) -> int: + return self.sites + + def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None: + pass + + def control_apply( + self, + controls: tuple[PyQrackQubit, ...], + targets: tuple[PyQrackQubit, ...], + adjoint: bool = False, + ) -> None: + pass + + +@dataclass(frozen=True) +class MultRuntime(OperatorRuntimeABC): + lhs: OperatorRuntimeABC + rhs: OperatorRuntimeABC + + @property + def n_sites(self) -> int: + if self.lhs.n_sites != self.rhs.n_sites: + raise RuntimeError("Multiplication of operators with unequal size.") + + return self.lhs.n_sites + + def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None: + if adjoint: + # NOTE: inverted order + self.lhs.apply(*qubits, adjoint=adjoint) + self.rhs.apply(*qubits, adjoint=adjoint) + else: + self.rhs.apply(*qubits) + self.lhs.apply(*qubits) + + def control_apply( + self, + controls: tuple[PyQrackQubit, ...], + targets: tuple[PyQrackQubit, ...], + adjoint: bool = False, + ) -> None: + if adjoint: + self.lhs.control_apply(controls=controls, targets=targets, adjoint=adjoint) + self.rhs.control_apply(controls=controls, targets=targets, adjoint=adjoint) + else: + self.rhs.control_apply(controls=controls, targets=targets, adjoint=adjoint) + self.lhs.control_apply(controls=controls, targets=targets, adjoint=adjoint) + + +@dataclass(frozen=True) +class KronRuntime(OperatorRuntimeABC): + lhs: OperatorRuntimeABC + rhs: OperatorRuntimeABC + + @property + def n_sites(self) -> int: + return self.lhs.n_sites + self.rhs.n_sites + + def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None: + self.lhs.apply(*qubits[: self.lhs.n_sites], adjoint=adjoint) + self.rhs.apply(*qubits[self.lhs.n_sites :], adjoint=adjoint) + + def control_apply( + self, + controls: tuple[PyQrackQubit, ...], + targets: tuple[PyQrackQubit, ...], + adjoint: bool = False, + ) -> None: + self.lhs.control_apply( + controls=controls, + targets=tuple(targets[: self.lhs.n_sites]), + adjoint=adjoint, + ) + self.rhs.control_apply( + controls=controls, + targets=tuple(targets[self.lhs.n_sites :]), + adjoint=adjoint, + ) + + +@dataclass(frozen=True) +class ScaleRuntime(OperatorRuntimeABC): + op: OperatorRuntimeABC + factor: complex + + @property + def n_sites(self) -> int: + return self.op.n_sites + + @staticmethod + def mat(factor, adjoint: bool): + if adjoint: + return [np.conj(factor), 0, 0, factor] + else: + return [factor, 0, 0, factor] + + def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None: + self.op.apply(*qubits, adjoint=adjoint) + + # NOTE: when applying to multiple qubits, we "spread" the factor evenly + applied_factor = self.factor ** (1.0 / len(qubits)) + for qbit in qubits: + if not qbit.is_active(): + continue + + # NOTE: just factor * eye(2) + m = self.mat(applied_factor, adjoint) + + # TODO: output seems to always be normalized -- no-op? + qbit.sim_reg.mtrx(m, qbit.addr) + + def control_apply( + self, + controls: tuple[PyQrackQubit, ...], + targets: tuple[PyQrackQubit, ...], + adjoint: bool = False, + ) -> None: + + ctrls: list[int] = [] + for qbit in controls: + if not qbit.is_active(): + return + + ctrls.append(qbit.addr) + + self.op.control_apply(controls=controls, targets=targets, adjoint=adjoint) + + applied_factor = self.factor ** (1.0 / len(targets)) + for target in targets: + m = self.mat(applied_factor, adjoint=adjoint) + target.sim_reg.mcmtrx(ctrls, m, target.addr) + + +@dataclass(frozen=True) +class MtrxOpRuntime(OperatorRuntimeABC): + def mat(self, adjoint: bool) -> list[complex]: + raise NotImplementedError("Override this method in the subclass!") + + @property + def n_sites(self) -> int: + # NOTE: pyqrack only supports 2x2 matrices, i.e. single qubit applications + return 1 + + def apply(self, target: PyQrackQubit, adjoint: bool = False) -> None: + if not target.is_active(): + return + + m = self.mat(adjoint=adjoint) + target.sim_reg.mtrx(m, target.addr) + + def control_apply( + self, + controls: tuple[PyQrackQubit, ...], + targets: tuple[PyQrackQubit, ...], + adjoint: bool = False, + ) -> None: + target = targets[0] + if not target.is_active(): + return + + ctrls: list[int] = [] + for qbit in controls: + if not qbit.is_active(): + return + + ctrls.append(qbit.addr) + + m = self.mat(adjoint=adjoint) + target.sim_reg.mcmtrx(ctrls, m, target.addr) + + +@dataclass(frozen=True) +class SpRuntime(MtrxOpRuntime): + def mat(self, adjoint: bool) -> list[complex]: + if adjoint: + return [0, 0, 0.5, 0] + else: + return [0, 0.5, 0, 0] + + +@dataclass(frozen=True) +class SnRuntime(MtrxOpRuntime): + def mat(self, adjoint: bool) -> list[complex]: + if adjoint: + return [0, 0.5, 0, 0] + else: + return [0, 0, 0.5, 0] + + +@dataclass(frozen=True) +class PhaseOpRuntime(MtrxOpRuntime): + theta: float + global_: bool + + def mat(self, adjoint: bool) -> list[complex]: + sign = (-1) ** (not adjoint) + local_phase = np.exp(sign * 1j * self.theta) + + # NOTE: this is just 1 if we want a local shift + global_phase = np.exp(sign * 1j * self.theta * self.global_) + + return [global_phase, 0, 0, local_phase] + + +@dataclass(frozen=True) +class RotRuntime(OperatorRuntimeABC): + axis: OperatorRuntimeABC + angle: float + pyqrack_axis: Pauli = field(init=False) + + @property + def n_sites(self) -> int: + return 1 + + def __post_init__(self): + if not isinstance(self.axis, OperatorRuntime): + raise RuntimeError( + f"Rotation only supported for Pauli operators! Got {self.axis}" + ) + + try: + axis = getattr(Pauli, "Pauli" + self.axis.method_name.upper()) + except KeyError: + raise RuntimeError( + f"Rotation only supported for Pauli operators! Got {self.axis}" + ) + + # NOTE: weird setattr for frozen dataclasses + object.__setattr__(self, "pyqrack_axis", axis) + + def apply(self, target: PyQrackQubit, adjoint: bool = False) -> None: + if not target.is_active(): + return + + sign = (-1) ** adjoint + angle = sign * self.angle + target.sim_reg.r(self.pyqrack_axis, angle, target.addr) + + def control_apply( + self, + controls: tuple[PyQrackQubit, ...], + targets: tuple[PyQrackQubit, ...], + adjoint: bool = False, + ) -> None: + target = targets[0] + if not target.is_active(): + return + + ctrls: list[int] = [] + for qbit in controls: + if not qbit.is_active(): + return + + ctrls.append(qbit.addr) + + sign = (-1) ** (not adjoint) + angle = sign * self.angle + target.sim_reg.mcr(self.pyqrack_axis, angle, ctrls, target.addr) + + +@dataclass(frozen=True) +class AdjointRuntime(OperatorRuntimeABC): + op: OperatorRuntimeABC + + @property + def n_sites(self) -> int: + return self.op.n_sites + + def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None: + # NOTE: to account for adjoint(adjoint(op)) + passed_on_adjoint = not adjoint + + self.op.apply(*qubits, adjoint=passed_on_adjoint) + + def control_apply( + self, + controls: tuple[PyQrackQubit, ...], + targets: tuple[PyQrackQubit, ...], + adjoint: bool = False, + ) -> None: + passed_on_adjoint = not adjoint + self.op.control_apply( + controls=controls, targets=targets, adjoint=passed_on_adjoint + ) + + +@dataclass(frozen=True) +class U3Runtime(OperatorRuntimeABC): + theta: float + phi: float + lam: float + + @property + def n_sites(self) -> int: + return 1 + + def angles(self, adjoint: bool) -> tuple[float, float, float]: + if adjoint: + # NOTE: adjoint(U(theta, phi, lam)) == U(-theta, -lam, -phi) + return -self.theta, -self.lam, -self.phi + else: + return self.theta, self.phi, self.lam + + def apply(self, target: PyQrackQubit, adjoint: bool = False) -> None: + if not target.is_active(): + return + + angles = self.angles(adjoint=adjoint) + target.sim_reg.u(target.addr, *angles) + + def control_apply( + self, + controls: tuple[PyQrackQubit, ...], + targets: tuple[PyQrackQubit, ...], + adjoint: bool = False, + ) -> None: + target = targets[0] + if not target.is_active(): + return + + ctrls: list[int] = [] + for qbit in controls: + if not qbit.is_active(): + return + + ctrls.append(qbit.addr) + + angles = self.angles(adjoint=adjoint) + target.sim_reg.mcu(ctrls, target.addr, *angles) + + +@dataclass(frozen=True) +class PauliStringRuntime(OperatorRuntimeABC): + string: str + ops: list[OperatorRuntime] + + @property + def n_sites(self) -> int: + return sum((op.n_sites for op in self.ops)) + + def apply(self, *qubits: PyQrackQubit, adjoint: bool = False): + if len(qubits) != self.n_sites: + raise RuntimeError( + f"Cannot apply Pauli string {self.string} to {len(qubits)} qubits! Make sure the number of qubits matches." + ) + + qubit_index = 0 + for op in self.ops: + next_qubit_index = qubit_index + op.n_sites + op.apply(*qubits[qubit_index:next_qubit_index], adjoint=adjoint) + qubit_index = next_qubit_index + + def control_apply( + self, + controls: tuple[PyQrackQubit, ...], + targets: tuple[PyQrackQubit, ...], + adjoint: bool = False, + ) -> None: + if len(targets) != self.n_sites: + raise RuntimeError( + f"Cannot apply Pauli string {self.string} to {len(targets)} qubits! Make sure the number of qubits matches." + ) + + for i, op in enumerate(self.ops): + # NOTE: this is fine as the size of each op is actually just 1 by definition + target = targets[i] + op.control_apply(controls=controls, targets=(target,)) diff --git a/src/bloqade/pyqrack/squin/wire.py b/src/bloqade/pyqrack/squin/wire.py new file mode 100644 index 00000000..c0f0b1d0 --- /dev/null +++ b/src/bloqade/pyqrack/squin/wire.py @@ -0,0 +1,69 @@ +from kirin import interp + +from bloqade.squin import wire +from bloqade.pyqrack.reg import PyQrackWire, PyQrackQubit +from bloqade.pyqrack.base import PyQrackInterpreter + +from .runtime import OperatorRuntimeABC + + +@wire.dialect.register(key="pyqrack") +class PyQrackMethods(interp.MethodTable): + # @interp.impl(wire.Wrap) + # def wrap(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: wire.Wrap): + # traits = frozenset({lowering.FromPythonCall(), WireTerminator()}) + # wire: ir.SSAValue = info.argument(WireType) + # qubit: ir.SSAValue = info.argument(QubitType) + + @interp.impl(wire.Unwrap) + def unwrap( + self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: wire.Unwrap + ): + q: PyQrackQubit = frame.get(stmt.qubit) + return (PyQrackWire(q),) + + @interp.impl(wire.Apply) + def apply(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: wire.Apply): + ws = stmt.inputs + assert isinstance(ws, tuple) + qubits: list[PyQrackQubit] = [] + for w in ws: + assert isinstance(w, PyQrackWire) + qubits.append(w.qubit) + op: OperatorRuntimeABC = frame.get(stmt.operator) + + op.apply(*qubits) + + out_ws = [PyQrackWire(qbit) for qbit in qubits] + return (out_ws,) + + @interp.impl(wire.Measure) + def measure( + self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: wire.Measure + ): + w: PyQrackWire = frame.get(stmt.wire) + qbit = w.qubit + res: int = qbit.sim_reg.m(qbit.addr) + return (res,) + + @interp.impl(wire.MeasureAndReset) + def measure_and_reset( + self, + interp: PyQrackInterpreter, + frame: interp.Frame, + stmt: wire.MeasureAndReset, + ): + w: PyQrackWire = frame.get(stmt.wire) + qbit = w.qubit + res: int = qbit.sim_reg.m(qbit.addr) + qbit.sim_reg.force_m(qbit.addr, False) + new_w = PyQrackWire(qbit) + return (new_w, res) + + @interp.impl(wire.Reset) + def reset(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: wire.Reset): + w: PyQrackWire = frame.get(stmt.wire) + qbit = w.qubit + qbit.sim_reg.force_m(qbit.addr, False) + new_w = PyQrackWire(qbit) + return (new_w,) diff --git a/src/bloqade/squin/analysis/nsites/impls.py b/src/bloqade/squin/analysis/nsites/impls.py index 3a4f94f1..74b6b759 100644 --- a/src/bloqade/squin/analysis/nsites/impls.py +++ b/src/bloqade/squin/analysis/nsites/impls.py @@ -1,6 +1,4 @@ -from typing import cast - -from kirin import ir, interp +from kirin import interp from bloqade.squin import op @@ -52,9 +50,7 @@ def control( if isinstance(op_sites, NumberSites): n_sites = op_sites.sites - n_controls_attr = stmt.get_attr_or_prop("n_controls") - n_controls = cast(ir.PyAttr[int], n_controls_attr).data - return (NumberSites(sites=n_sites + n_controls),) + return (NumberSites(sites=n_sites + stmt.n_controls),) else: return (NoSites(),) diff --git a/src/bloqade/squin/op/__init__.py b/src/bloqade/squin/op/__init__.py index 77b07c64..998d45e0 100644 --- a/src/bloqade/squin/op/__init__.py +++ b/src/bloqade/squin/op/__init__.py @@ -8,19 +8,41 @@ @_wraps(stmts.Kron) -def kron(lhs: types.Op, rhs: types.Op, *, is_unitary: bool = False) -> types.Op: ... +def kron(lhs: types.Op, rhs: types.Op) -> types.Op: ... + + +@_wraps(stmts.Mult) +def mult(lhs: types.Op, rhs: types.Op) -> types.Op: ... + + +@_wraps(stmts.Scale) +def scale(op: types.Op, factor: complex) -> types.Op: ... @_wraps(stmts.Adjoint) -def adjoint(op: types.Op, *, is_unitary: bool = False) -> types.Op: ... +def adjoint(op: types.Op) -> types.Op: ... @_wraps(stmts.Control) -def control(op: types.Op, *, n_controls: int, is_unitary: bool = False) -> types.Op: ... +def control(op: types.Op, *, n_controls: int) -> types.Op: + """ + Create a controlled operator. + + Note, that when considering atom loss, the operator will not be applied if + any of the controls has been lost. + + Args: + operator: The operator to apply under the control. + n_controls: The number qubits to be used as control. + + Returns: + Operator + """ + ... @_wraps(stmts.Identity) -def identity(*, size: int) -> types.Op: ... +def identity(*, sites: int) -> types.Op: ... @_wraps(stmts.Rot) @@ -75,6 +97,14 @@ def spin_n() -> types.Op: ... def spin_p() -> types.Op: ... +@_wraps(stmts.U3) +def u(theta: float, phi: float, lam: float) -> types.Op: ... + + +@_wraps(stmts.PauliString) +def pauli_string(*, string: str) -> types.Op: ... + + # stdlibs @_ir.dialect_group(_structural_no_opt.add(dialect)) def op(self): diff --git a/src/bloqade/squin/op/stmts.py b/src/bloqade/squin/op/stmts.py index 9f948510..76200133 100644 --- a/src/bloqade/squin/op/stmts.py +++ b/src/bloqade/squin/op/stmts.py @@ -103,6 +103,7 @@ class ConstantUnitary(ConstantOp): ) +@statement(dialect=dialect) class U3(PrimitiveOp): traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), FixedSites(1)}) theta: ir.SSAValue = info.argument(types.Float) diff --git a/src/bloqade/squin/qubit.py b/src/bloqade/squin/qubit.py index 6d6c53c6..355b37d8 100644 --- a/src/bloqade/squin/qubit.py +++ b/src/bloqade/squin/qubit.py @@ -77,6 +77,7 @@ class MeasureAndReset(ir.Statement): @statement(dialect=dialect) class Reset(ir.Statement): + traits = frozenset({lowering.FromPythonCall()}) qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) @@ -98,6 +99,8 @@ def new(n_qubits: int) -> ilist.IList[Qubit, Any]: def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None: """Apply an operator to a list of qubits. + Note, that when considering atom loss, lost qubits will be skipped. + Args: operator: The operator to apply. qubits: The list of qubits to apply the operator to. The size of the list @@ -112,7 +115,7 @@ def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None: @overload def measure(input: Qubit) -> bool: ... @overload -def measure(input: ilist.IList[Qubit, Any] | list[Qubit]) -> list[bool]: ... +def measure(input: ilist.IList[Qubit, Any] | list[Qubit]) -> ilist.IList[bool, Any]: ... @wraps(MeasureAny) @@ -135,6 +138,20 @@ def broadcast(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> No """Broadcast and apply an operator to a list of qubits. For example, an operator that expects 2 qubits can be applied to a list of 2n qubits, where n is an integer > 0. + For controlled operators, the list of qubits is interpreted as sets of (controls, targets). + For example + + ``` + apply(CX, [q0, q1, q2, q3]) + ``` + + is equivalent to + + ``` + apply(CX, [q0, q1]) + apply(CX, [q2, q3]) + ``` + Args: operator: The operator to broadcast and apply. qubits: The list of qubits to broadcast and apply the operator to. The size of the list @@ -147,14 +164,14 @@ def broadcast(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> No @wraps(MeasureAndReset) -def measure_and_reset(qubits: ilist.IList[Qubit, Any]) -> int: +def measure_and_reset(qubits: ilist.IList[Qubit, Any]) -> ilist.IList[bool, Any]: """Measure the qubits in the list and reset them." Args: qubits: The list of qubits to measure and reset. Returns: - int: The result of the measurement. + list[bool]: The result of the measurement. """ ... diff --git a/src/bloqade/squin/wire.py b/src/bloqade/squin/wire.py index f2a70d0d..6734791b 100644 --- a/src/bloqade/squin/wire.py +++ b/src/bloqade/squin/wire.py @@ -8,10 +8,11 @@ from kirin import ir, types, interp, lowering from kirin.decl import info, statement +from kirin.lowering import wraps -from bloqade.types import QubitType +from bloqade.types import Qubit, QubitType -from .op.types import OpType +from .op.types import Op, OpType # from kirin.lowering import wraps @@ -121,3 +122,11 @@ class ConstPropWire(interp.MethodTable): def apply(self, interp, frame, stmt: Apply): return frame.get_values(stmt.inputs) + + +@wraps(Unwrap) +def unwrap(qubit: Qubit) -> Wire: ... + + +@wraps(Apply) +def apply(op: Op, w: Wire) -> Wire: ... diff --git a/test/pyqrack/test_squin.py b/test/pyqrack/test_squin.py new file mode 100644 index 00000000..f7928630 --- /dev/null +++ b/test/pyqrack/test_squin.py @@ -0,0 +1,514 @@ +import math + +import pytest +from kirin.dialects import ilist + +from bloqade import squin +from bloqade.pyqrack import PyQrack, PyQrackWire, PyQrackQubit + + +def test_qubit(): + @squin.kernel + def new(): + return squin.qubit.new(3) + + new.print() + + target = PyQrack( + 3, pyqrack_options={"isBinaryDecisionTree": False, "isStabilizerHybrid": True} + ) + result = target.run(new) + assert isinstance(result, ilist.IList) + assert isinstance(qubit := result[0], PyQrackQubit) + + out = qubit.sim_reg.out_ket() + assert out == [1.0] + [0.0] * (2**3 - 1) + + @squin.kernel + def m(): + q = squin.qubit.new(3) + m = squin.qubit.measure(q) + squin.qubit.reset(q) + return m + + target = PyQrack(3) + result = target.run(m) + assert isinstance(result, ilist.IList) + assert result.data == [0, 0, 0] + + @squin.kernel + def measure_and_reset(): + q = squin.qubit.new(3) + m = squin.qubit.measure_and_reset(q) + return m + + target = PyQrack(3) + result = target.run(measure_and_reset) + assert isinstance(result, ilist.IList) + assert result.data == [0, 0, 0] + + +def test_x(): + @squin.kernel + def main(): + q = squin.qubit.new(1) + x = squin.op.x() + squin.qubit.apply(x, q) + return squin.qubit.measure(q[0]) + + target = PyQrack(1) + result = target.run(main) + assert result == 1 + + +@pytest.mark.parametrize( + "op_name", + [ + "x", + "y", + "z", + "h", + "s", + "t", + ], +) +def test_basic_ops(op_name: str): + @squin.kernel + def main(): + q = squin.qubit.new(1) + op = getattr(squin.op, op_name)() + squin.qubit.apply(op, q) + return q + + target = PyQrack(1) + result = target.run(main) + assert isinstance(result, ilist.IList) + assert isinstance(qubit := result[0], PyQrackQubit) + + ket = qubit.sim_reg.out_ket() + n = sum([abs(k) ** 2 for k in ket]) + assert math.isclose(n, 1, abs_tol=1e-6) + + +def test_cx(): + @squin.kernel + def main(): + q = squin.qubit.new(2) + x = squin.op.x() + cx = squin.op.control(x, n_controls=1) + squin.qubit.apply(cx, q) + return squin.qubit.measure(q[1]) + + target = PyQrack(2) + result = target.run(main) + assert result == 0 + + @squin.kernel + def main2(): + q = squin.qubit.new(2) + x = squin.op.x() + id = squin.op.identity(sites=1) + cx = squin.op.control(x, n_controls=1) + squin.qubit.apply(squin.op.kron(x, id), q) + squin.qubit.apply(cx, q) + return squin.qubit.measure(q[0]) + + target = PyQrack(2) + result = target.run(main2) + assert result == 1 + + @squin.kernel + def main3(): + q = squin.qubit.new(2) + x = squin.op.adjoint(squin.op.x()) + id = squin.op.identity(sites=1) + cx = squin.op.control(x, n_controls=1) + squin.qubit.apply(squin.op.kron(x, id), q) + squin.qubit.apply(cx, q) + return squin.qubit.measure(q[0]) + + target = PyQrack(2) + result = target.run(main3) + assert result == 1 + + +def test_cxx(): + @squin.kernel + def main(): + q = squin.qubit.new(3) + x = squin.op.x() + cxx = squin.op.control(squin.op.kron(x, x), n_controls=1) + squin.qubit.apply(x, [q[0]]) + squin.qubit.apply(cxx, q) + return squin.qubit.measure(q) + + target = PyQrack(3) + result = target.run(main) + assert result == ilist.IList([1, 1, 1]) + + +def test_mult(): + @squin.kernel + def main(): + q = squin.qubit.new(1) + x = squin.op.x() + id = squin.op.mult(x, x) + squin.qubit.apply(id, q) + return squin.qubit.measure(q[0]) + + main.print() + + target = PyQrack(1) + result = target.run(main) + + assert result == 0 + + +def test_kron(): + @squin.kernel + def main(): + q = squin.qubit.new(2) + x = squin.op.x() + k = squin.op.kron(x, x) + squin.qubit.apply(k, q) + return squin.qubit.measure(q) + + target = PyQrack(2) + result = target.run(main) + + assert result == ilist.IList([1, 1]) + + +def test_scale(): + @squin.kernel + def main(): + q = squin.qubit.new(1) + x = squin.op.x() + + # TODO: replace by 2 * x once we have the rewrite + s = squin.op.scale(x, 2) + + squin.qubit.apply(s, q) + return squin.qubit.measure(q[0]) + + target = PyQrack(1) + result = target.run(main) + assert result == 1 + + +def test_phase(): + @squin.kernel + def main(): + q = squin.qubit.new(1) + h = squin.op.h() + squin.qubit.apply(h, q) + + p = squin.op.shift(math.pi) + squin.qubit.apply(p, q) + + squin.qubit.apply(h, q) + return squin.qubit.measure(q[0]) + + target = PyQrack(1) + result = target.run(main) + assert result == 1 + + +def test_sp(): + @squin.kernel + def main(): + q = squin.qubit.new(1) + sp = squin.op.spin_p() + squin.qubit.apply(sp, q) + return q + + target = PyQrack(1) + result = target.run(main) + assert isinstance(result, ilist.IList) + assert isinstance(qubit := result[0], PyQrackQubit) + + assert qubit.sim_reg.out_ket() == [0, 0] + + @squin.kernel + def main2(): + q = squin.qubit.new(1) + sn = squin.op.spin_n() + sp = squin.op.spin_p() + squin.qubit.apply(sn, q) + squin.qubit.apply(sp, q) + return squin.qubit.measure(q[0]) + + target = PyQrack(1) + result = target.run(main2) + assert result == 0 + + +def test_adjoint(): + @squin.kernel + def main(): + q = squin.qubit.new(1) + x = squin.op.x() + xadj = squin.op.adjoint(x) + squin.qubit.apply(xadj, q) + return squin.qubit.measure(q[0]) + + target = PyQrack(1) + result = target.run(main) + assert result == 1 + + @squin.kernel + def adj_that_does_something(): + q = squin.qubit.new(1) + s = squin.op.s() + sadj = squin.op.adjoint(s) + h = squin.op.h() + + squin.qubit.apply(h, q) + squin.qubit.apply(s, q) + squin.qubit.apply(sadj, q) + squin.qubit.apply(h, q) + return squin.qubit.measure(q[0]) + + target = PyQrack(1) + result = target.run(adj_that_does_something) + assert result == 0 + + @squin.kernel + def adj_of_adj(): + q = squin.qubit.new(1) + s = squin.op.s() + sadj = squin.op.adjoint(s) + sadj_adj = squin.op.adjoint(sadj) + h = squin.op.h() + + squin.qubit.apply(h, q) + squin.qubit.apply(sadj, q) + squin.qubit.apply(sadj_adj, q) + squin.qubit.apply(h, q) + return squin.qubit.measure(q[0]) + + target = PyQrack(1) + result = target.run(adj_of_adj) + assert result == 0 + + @squin.kernel + def nested_adj(): + q = squin.qubit.new(1) + s = squin.op.s() + sadj = squin.op.adjoint(s) + s_nested_adj = squin.op.adjoint(squin.op.adjoint(squin.op.adjoint(sadj))) + + h = squin.op.h() + + squin.qubit.apply(h, q) + squin.qubit.apply(sadj, q) + squin.qubit.apply(s_nested_adj, q) + squin.qubit.apply(h, q) + return squin.qubit.measure(q[0]) + + target = PyQrack(1) + result = target.run(nested_adj) + assert result == 0 + + +def test_rot(): + @squin.kernel + def main_x(): + q = squin.qubit.new(1) + x = squin.op.x() + r = squin.op.rot(x, math.pi) + squin.qubit.apply(r, q) + return squin.qubit.measure(q[0]) + + target = PyQrack(1) + result = target.run(main_x) + assert result == 1 + + @squin.kernel + def main_y(): + q = squin.qubit.new(1) + y = squin.op.y() + r = squin.op.rot(y, math.pi) + squin.qubit.apply(r, q) + return squin.qubit.measure(q[0]) + + target = PyQrack(1) + result = target.run(main_y) + assert result == 1 + + @squin.kernel + def main_z(): + q = squin.qubit.new(1) + z = squin.op.z() + r = squin.op.rot(z, math.pi) + squin.qubit.apply(r, q) + return squin.qubit.measure(q[0]) + + target = PyQrack(1) + result = target.run(main_z) + assert result == 0 + + +def test_broadcast(): + @squin.kernel + def main(): + q = squin.qubit.new(3) + x = squin.op.x() + squin.qubit.broadcast(x, q) + return squin.qubit.measure(q) + + target = PyQrack(3) + result = target.run(main) + assert result == ilist.IList([1, 1, 1]) + + @squin.kernel + def multi_site_bc(): + q = squin.qubit.new(6) + x = squin.op.x() + + # invert controls + squin.qubit.apply(x, [q[0]]) + squin.qubit.apply(x, [q[1]]) + + cx = squin.op.control(x, n_controls=2) + squin.qubit.broadcast(cx, q) + return squin.qubit.measure(q) + + target = PyQrack(6) + result = target.run(multi_site_bc) + assert result == ilist.IList([1, 1, 1, 0, 0, 0]) + + @squin.kernel + def bc_size_mismatch(): + q = squin.qubit.new(5) + x = squin.op.x() + + # invert controls + squin.qubit.apply(x, [q[0]]) + squin.qubit.apply(x, [q[1]]) + + cx = squin.op.control(x, n_controls=2) + squin.qubit.broadcast(cx, q) + return squin.qubit.measure(q) + + target = PyQrack(5) + + with pytest.raises(RuntimeError): + target.run(bc_size_mismatch) + + +def test_u3(): + @squin.kernel + def broadcast_h(): + q = squin.qubit.new(3) + + # rotate around Y by pi/2, i.e. perform a hadamard + u = squin.op.u(math.pi / 2.0, 0, 0) + + squin.qubit.broadcast(u, q) + return q + + target = PyQrack(3) + q = target.run(broadcast_h) + + assert isinstance(q, ilist.IList) + assert isinstance(qubit := q[0], PyQrackQubit) + + out = qubit.sim_reg.out_ket() + + # remove global phase introduced by pyqrack + phase = out[0] / abs(out[0]) + out = [ele / phase for ele in out] + + for element in out: + assert math.isclose(element.real, 1 / math.sqrt(8), abs_tol=2.2e-7) + assert math.isclose(element.imag, 0, abs_tol=2.2e-7) + + @squin.kernel + def broadcast_adjoint(): + q = squin.qubit.new(3) + + # rotate around Y by pi/2, i.e. perform a hadamard + u = squin.op.u(math.pi / 2.0, 0, 0) + + squin.qubit.broadcast(u, q) + + # rotate back down + u_adj = squin.op.adjoint(u) + squin.qubit.broadcast(u_adj, q) + return squin.qubit.measure(q) + + target = PyQrack(3) + result = target.run(broadcast_adjoint) + assert result == ilist.IList([0, 0, 0]) + + +def test_projectors(): + @squin.kernel + def main_p0(): + q = squin.qubit.new(1) + h = squin.op.h() + p0 = squin.op.p0() + squin.qubit.apply(h, q) + squin.qubit.apply(p0, q) + return squin.qubit.measure(q[0]) + + target = PyQrack(1) + result = target.run(main_p0) + assert result == 0 + + @squin.kernel + def main_p1(): + q = squin.qubit.new(1) + h = squin.op.h() + p1 = squin.op.p1() + squin.qubit.apply(h, q) + squin.qubit.apply(p1, q) + return squin.qubit.measure(q[0]) + + target = PyQrack(1) + result = target.run(main_p1) + assert result == 1 + + +def test_pauli_str(): + @squin.kernel + def main(): + q = squin.qubit.new(3) + cstr = squin.op.pauli_string(string="XXX") + squin.qubit.apply(cstr, q) + return squin.qubit.measure(q) + + target = PyQrack(3) + result = target.run(main) + assert result == ilist.IList([1, 1, 1]) + + +def test_identity(): + @squin.kernel + def main(): + x = squin.op.x() + q = squin.qubit.new(3) + id = squin.op.identity(sites=2) + squin.qubit.apply(squin.op.kron(x, id), q) + return squin.qubit.measure(q) + + target = PyQrack(3) + result = target.run(main) + assert result == ilist.IList([1, 0, 0]) + + +@pytest.mark.xfail +def test_wire(): + @squin.wired + def main(): + q = squin.qubit.new(1) + w = squin.wire.unwrap(q[0]) + x = squin.op.x() + squin.wire.apply(x, w) + return w + + target = PyQrack(1) + result = target.run(main) + assert isinstance(result, PyQrackWire) + assert result.qubit.sim_reg.out_ket() == [0, 1]