diff --git a/src/bloqade/pyqrack/__init__.py b/src/bloqade/pyqrack/__init__.py index 4c6ebf59..27fd3e54 100644 --- a/src/bloqade/pyqrack/__init__.py +++ b/src/bloqade/pyqrack/__init__.py @@ -1,7 +1,6 @@ from .reg import ( CBitRef as CBitRef, CRegister as CRegister, - PyQrackReg as PyQrackReg, QubitState as QubitState, Measurement as Measurement, PyQrackQubit as PyQrackQubit, diff --git a/src/bloqade/pyqrack/noise/native.py b/src/bloqade/pyqrack/noise/native.py index e5002c33..13192a77 100644 --- a/src/bloqade/pyqrack/noise/native.py +++ b/src/bloqade/pyqrack/noise/native.py @@ -93,7 +93,7 @@ def atom_loss_channel( for qarg in active_qubits: if interp.rng_state.uniform() <= stmt.prob: - qarg.ref.sim_reg.m(qarg.addr) + qarg.sim_reg.m(qarg.addr) qarg.drop() return () diff --git a/src/bloqade/pyqrack/qasm2/core.py b/src/bloqade/pyqrack/qasm2/core.py index aacad31d..4b6430db 100644 --- a/src/bloqade/pyqrack/qasm2/core.py +++ b/src/bloqade/pyqrack/qasm2/core.py @@ -1,9 +1,12 @@ +from typing import Any + from kirin import interp +from kirin.interp import InterpreterError +from kirin.dialects import ilist from bloqade.pyqrack.reg import ( CBitRef, CRegister, - PyQrackReg, QubitState, Measurement, PyQrackQubit, @@ -19,14 +22,13 @@ def qreg_new( self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.QRegNew ): n_qubits: int = frame.get(stmt.n_qubits) - return ( - PyQrackReg( - size=n_qubits, - sim_reg=interp.memory.sim_reg, - addrs=interp.memory.allocate(n_qubits), - qubit_state=[QubitState.Active] * n_qubits, - ), + qreg = ilist.IList( + [ + PyQrackQubit(i, interp.memory.sim_reg, QubitState.Active) + for i in interp.memory.allocate(n_qubits=n_qubits) + ] ) + return (qreg,) @interp.impl(core.CRegNew) def creg_new( @@ -39,7 +41,9 @@ def creg_new( def qreg_get( self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.QRegGet ): - return (PyQrackQubit(ref=frame.get(stmt.reg), pos=frame.get(stmt.idx)),) + reg = frame.get(stmt.reg) + i = frame.get(stmt.idx) + return (reg[i],) @interp.impl(core.CRegGet) def creg_get( @@ -51,7 +55,7 @@ def creg_get( def measure( self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.Measure ): - qarg: PyQrackQubit | PyQrackReg = frame.get(stmt.qarg) + qarg: PyQrackQubit | ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qarg) carg: CBitRef | CRegister = frame.get(stmt.carg) if isinstance(qarg, PyQrackQubit) and isinstance(carg, CBitRef): @@ -59,20 +63,15 @@ def measure( carg.set_value(Measurement(qarg.sim_reg.m(qarg.addr))) else: carg.set_value(interp.loss_m_result) - elif isinstance(qarg, PyQrackReg) and isinstance(carg, CRegister): - # TODO: clean up iteration after PyQrackReg is refactored - for i in range(qarg.size): - qubit = qarg[i] - - # TODO: make this consistent with PyQrackReg __getitem__ ? + elif isinstance(qarg, ilist.IList) and isinstance(carg, CRegister): + for i, qubit in enumerate(qarg): cbit = CBitRef(carg, i) - if qubit.is_active(): - cbit.set_value(Measurement(qarg.sim_reg.m(qubit.addr))) + cbit.set_value(Measurement(qubit.sim_reg.m(qubit.addr))) else: cbit.set_value(interp.loss_m_result) else: - raise RuntimeError( + raise InterpreterError( f"Expected measure call on either a single qubit and classical bit, or two registers, but got the types {type(qarg)} and {type(carg)}" ) diff --git a/src/bloqade/pyqrack/qasm2/glob.py b/src/bloqade/pyqrack/qasm2/glob.py index 730f71b6..1a99a7e1 100644 --- a/src/bloqade/pyqrack/qasm2/glob.py +++ b/src/bloqade/pyqrack/qasm2/glob.py @@ -3,7 +3,7 @@ from kirin import interp from kirin.dialects import ilist -from bloqade.pyqrack.reg import PyQrackReg +from bloqade.pyqrack.reg import PyQrackQubit from bloqade.pyqrack.base import PyQrackInterpreter from bloqade.qasm2.dialects import glob @@ -12,7 +12,9 @@ class PyQrackMethods(interp.MethodTable): @interp.impl(glob.UGate) def ugate(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: glob.UGate): - registers: ilist.IList[PyQrackReg, Any] = frame.get(stmt.registers) + registers: ilist.IList[ilist.IList[PyQrackQubit, Any], Any] = frame.get( + stmt.registers + ) theta, phi, lam = ( frame.get(stmt.theta), frame.get(stmt.phi), diff --git a/src/bloqade/pyqrack/reg.py b/src/bloqade/pyqrack/reg.py index 38838e2e..7f498776 100644 --- a/src/bloqade/pyqrack/reg.py +++ b/src/bloqade/pyqrack/reg.py @@ -1,8 +1,8 @@ import enum -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING from dataclasses import dataclass -from bloqade.qasm2.types import QReg, Qubit +from bloqade.qasm2.types import Qubit if TYPE_CHECKING: from pyqrack import QrackSimulator @@ -45,57 +45,18 @@ class QubitState(enum.Enum): Lost = enum.auto() -@dataclass(frozen=True) -class PyQrackReg(QReg): # TODO: clean up implementation with list base class - """Simulation runtime value of a quantum register.""" - - size: int - """The number of qubits in this register.""" - - sim_reg: "QrackSimulator" - """The register of the simulator.""" - - addrs: tuple[int, ...] - """The global addresses of the qubits in this register.""" - - qubit_state: List[QubitState] - """The state of each qubit in this register.""" - - def drop(self, pos: int): - """Drop the qubit at the given position in-place. - - Args - pos (int): The position of the qubit to drop. - - """ - assert self.qubit_state[pos] is QubitState.Active, "Qubit already lost" - self.qubit_state[pos] = QubitState.Lost - - def __getitem__(self, pos: int): - if not 0 <= pos < self.size: - raise IndexError("Qubit index out of bounds of register.") - return PyQrackQubit(self, pos) - - -@dataclass(frozen=True) +@dataclass class PyQrackQubit(Qubit): """The runtime representation of a qubit reference.""" - ref: PyQrackReg - """The quantum register that is holding this qubit.""" + addr: int + """The address of this qubit in the quantum register.""" - pos: int - """The position of this qubit in the quantum register.""" - - @property - def sim_reg(self): - """The register of the simulator.""" - return self.ref.sim_reg + sim_reg: "QrackSimulator" + """The register of the simulator.""" - @property - def addr(self) -> int: - """The global address of the qubit.""" - return self.ref.addrs[self.pos] + state: QubitState + """The state of the qubit (active/lost)""" def is_active(self) -> bool: """Check if the qubit is active. @@ -104,8 +65,8 @@ def is_active(self) -> bool: True if the qubit is active, False otherwise. """ - return self.ref.qubit_state[self.pos] is QubitState.Active + return self.state is QubitState.Active def drop(self): """Drop the qubit in-place.""" - self.ref.drop(self.pos) + self.state = QubitState.Lost diff --git a/src/bloqade/qasm2/types.py b/src/bloqade/qasm2/types.py index 827bfde0..7dd4e136 100644 --- a/src/bloqade/qasm2/types.py +++ b/src/bloqade/qasm2/types.py @@ -1,4 +1,5 @@ from kirin import types +from kirin.dialects import ilist from bloqade.types import Qubit as Qubit, QubitType as QubitType @@ -15,11 +16,7 @@ class Bit: pass -class QReg: - """Runtime representation of a quantum register.""" - - def __getitem__(self, index) -> Qubit: - raise NotImplementedError("cannot call __getitem__ outside of a kernel") +QReg = ilist.IList[Qubit, types.Any] class CReg: @@ -32,7 +29,7 @@ def __getitem__(self, index) -> Bit: BitType = types.PyClass(Bit) """Kirin type for a classical bit.""" -QRegType = types.PyClass(QReg) +QRegType = ilist.IListType[QubitType, types.Any] """Kirin type for a quantum register.""" CRegType = types.PyClass(CReg) diff --git a/test/pyqrack/runtime/noise/native/test_loss.py b/test/pyqrack/runtime/noise/native/test_loss.py index 174350ae..356aec9b 100644 --- a/test/pyqrack/runtime/noise/native/test_loss.py +++ b/test/pyqrack/runtime/noise/native/test_loss.py @@ -1,10 +1,12 @@ +from typing import Literal from unittest.mock import Mock from kirin import ir +from kirin.dialects import ilist from bloqade import qasm2 from bloqade.noise import native -from bloqade.pyqrack import PyQrackInterpreter, reg +from bloqade.pyqrack import PyQrackQubit, PyQrackInterpreter, reg from bloqade.pyqrack.base import MockMemory simulation = qasm2.extended.add(native) @@ -34,12 +36,12 @@ def test_atom_loss(c: qasm2.CReg): input = reg.CRegister(1) memory = MockMemory() - result: reg.PyQrackReg = ( + result: ilist.IList[PyQrackQubit, Literal[2]] = ( PyQrackInterpreter(simulation, memory=memory, rng_state=rng_state) .run(test_atom_loss, (input,)) .expect() ) - assert result.qubit_state[0] is reg.QubitState.Lost - assert result.qubit_state[1] is reg.QubitState.Active + assert result[0].state is reg.QubitState.Lost + assert result[1].state is reg.QubitState.Active assert input[0] is reg.Measurement.One diff --git a/test/pyqrack/test_target.py b/test/pyqrack/test_target.py index 56156dad..b1f23729 100644 --- a/test/pyqrack/test_target.py +++ b/test/pyqrack/test_target.py @@ -2,9 +2,10 @@ import pytest from kirin import ir +from kirin.dialects import ilist from bloqade import qasm2 -from bloqade.pyqrack import PyQrack, reg +from bloqade.pyqrack import PyQrack, PyQrackQubit, reg def test_target(): @@ -23,9 +24,10 @@ def ghz(): q = target.run(ghz) - assert isinstance(q, reg.PyQrackReg) + assert isinstance(q, ilist.IList) + assert isinstance(qubit := q[0], PyQrackQubit) - out = q.sim_reg.out_ket() + out = qubit.sim_reg.out_ket() norm = math.sqrt(sum(abs(ele) ** 2 for ele in out)) phase = out[0] / abs(out[0]) @@ -56,9 +58,10 @@ def global_h(): target = PyQrack(3) q = target.run(global_h) - assert isinstance(q, reg.PyQrackReg) + assert isinstance(q, ilist.IList) + assert isinstance(qubit := q[0], PyQrackQubit) - out = q.sim_reg.out_ket() + out = qubit.sim_reg.out_ket() # remove global phase introduced by pyqrack phase = out[0] / abs(out[0]) @@ -103,9 +106,10 @@ def multiple_registers(): target = PyQrack(6) q1 = target.run(multiple_registers) - assert isinstance(q1, reg.PyQrackReg) + assert isinstance(q1, ilist.IList) + assert isinstance(qubit := q1[0], PyQrackQubit) - out = q1.sim_reg.out_ket() + out = qubit.sim_reg.out_ket() assert out[0] == 1 for i in range(1, len(out)): @@ -148,3 +152,19 @@ def measurement_that_errors(): q = qasm2.qreg(1) c = qasm2.creg(1) qasm2.measure(q[0], c) + + +def test_qreg_parallel(): + # test for #161 + @qasm2.extended + def parallel(): + qreg = qasm2.qreg(4) + creg = qasm2.creg(4) + qasm2.parallel.u(qreg, theta=math.pi, phi=0.0, lam=0.0) + qasm2.measure(qreg, creg) + return creg + + target = PyQrack(4) + result = target.run(parallel) + + assert result == [reg.Measurement.One] * 4