Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/bloqade/pyqrack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from .reg import (
CBitRef as CBitRef,
CRegister as CRegister,
PyQrackReg as PyQrackReg,
QubitState as QubitState,
Measurement as Measurement,
PyQrackQubit as PyQrackQubit,
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/pyqrack/noise/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def atom_loss_channel(

for qarg in active_qubits:
if interp.rng_state.uniform() <= stmt.prob:
qarg.ref.sim_reg.m(qarg.addr)
qarg.sim_reg.m(qarg.addr)
qarg.drop()

return ()
37 changes: 18 additions & 19 deletions src/bloqade/pyqrack/qasm2/core.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Any

from kirin import interp
from kirin.interp import InterpreterError
from kirin.dialects import ilist

from bloqade.pyqrack.reg import (
CBitRef,
CRegister,
PyQrackReg,
QubitState,
Measurement,
PyQrackQubit,
Expand All @@ -19,14 +22,13 @@
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.QRegNew
):
n_qubits: int = frame.get(stmt.n_qubits)
return (
PyQrackReg(
size=n_qubits,
sim_reg=interp.memory.sim_reg,
addrs=interp.memory.allocate(n_qubits),
qubit_state=[QubitState.Active] * n_qubits,
),
qreg = ilist.IList(
[
PyQrackQubit(i, interp.memory.sim_reg, QubitState.Active)
for i in interp.memory.allocate(n_qubits=n_qubits)
]
)
return (qreg,)

@interp.impl(core.CRegNew)
def creg_new(
Expand All @@ -39,7 +41,9 @@
def qreg_get(
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.QRegGet
):
return (PyQrackQubit(ref=frame.get(stmt.reg), pos=frame.get(stmt.idx)),)
reg = frame.get(stmt.reg)
i = frame.get(stmt.idx)
return (reg[i],)

@interp.impl(core.CRegGet)
def creg_get(
Expand All @@ -51,28 +55,23 @@
def measure(
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.Measure
):
qarg: PyQrackQubit | PyQrackReg = frame.get(stmt.qarg)
qarg: PyQrackQubit | ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qarg)
carg: CBitRef | CRegister = frame.get(stmt.carg)

if isinstance(qarg, PyQrackQubit) and isinstance(carg, CBitRef):
if qarg.is_active():
carg.set_value(Measurement(qarg.sim_reg.m(qarg.addr)))
else:
carg.set_value(interp.loss_m_result)
elif isinstance(qarg, PyQrackReg) and isinstance(carg, CRegister):
# TODO: clean up iteration after PyQrackReg is refactored
for i in range(qarg.size):
qubit = qarg[i]

# TODO: make this consistent with PyQrackReg __getitem__ ?
elif isinstance(qarg, ilist.IList) and isinstance(carg, CRegister):
for i, qubit in enumerate(qarg):
cbit = CBitRef(carg, i)

if qubit.is_active():
cbit.set_value(Measurement(qarg.sim_reg.m(qubit.addr)))
cbit.set_value(Measurement(qubit.sim_reg.m(qubit.addr)))
else:
cbit.set_value(interp.loss_m_result)
else:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh BTW I missed this in the last PR, you should raise an InterpreterError here not a RuntimeError.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

specifically in the else branch here.

raise RuntimeError(
raise InterpreterError(

Check warning on line 74 in src/bloqade/pyqrack/qasm2/core.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/pyqrack/qasm2/core.py#L74

Added line #L74 was not covered by tests
f"Expected measure call on either a single qubit and classical bit, or two registers, but got the types {type(qarg)} and {type(carg)}"
)

Expand Down
6 changes: 4 additions & 2 deletions src/bloqade/pyqrack/qasm2/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from kirin import interp
from kirin.dialects import ilist

from bloqade.pyqrack.reg import PyQrackReg
from bloqade.pyqrack.reg import PyQrackQubit
from bloqade.pyqrack.base import PyQrackInterpreter
from bloqade.qasm2.dialects import glob

Expand All @@ -12,7 +12,9 @@
class PyQrackMethods(interp.MethodTable):
@interp.impl(glob.UGate)
def ugate(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: glob.UGate):
registers: ilist.IList[PyQrackReg, Any] = frame.get(stmt.registers)
registers: ilist.IList[ilist.IList[PyQrackQubit, Any], Any] = frame.get(
stmt.registers
)
theta, phi, lam = (
frame.get(stmt.theta),
frame.get(stmt.phi),
Expand Down
61 changes: 11 additions & 50 deletions src/bloqade/pyqrack/reg.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import enum
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING
from dataclasses import dataclass

from bloqade.qasm2.types import QReg, Qubit
from bloqade.qasm2.types import Qubit

if TYPE_CHECKING:
from pyqrack import QrackSimulator
Expand Down Expand Up @@ -45,57 +45,18 @@ class QubitState(enum.Enum):
Lost = enum.auto()


@dataclass(frozen=True)
class PyQrackReg(QReg): # TODO: clean up implementation with list base class
"""Simulation runtime value of a quantum register."""

size: int
"""The number of qubits in this register."""

sim_reg: "QrackSimulator"
"""The register of the simulator."""

addrs: tuple[int, ...]
"""The global addresses of the qubits in this register."""

qubit_state: List[QubitState]
"""The state of each qubit in this register."""

def drop(self, pos: int):
"""Drop the qubit at the given position in-place.

Args
pos (int): The position of the qubit to drop.

"""
assert self.qubit_state[pos] is QubitState.Active, "Qubit already lost"
self.qubit_state[pos] = QubitState.Lost

def __getitem__(self, pos: int):
if not 0 <= pos < self.size:
raise IndexError("Qubit index out of bounds of register.")
return PyQrackQubit(self, pos)


@dataclass(frozen=True)
@dataclass
class PyQrackQubit(Qubit):
"""The runtime representation of a qubit reference."""

ref: PyQrackReg
"""The quantum register that is holding this qubit."""
addr: int
"""The address of this qubit in the quantum register."""

pos: int
"""The position of this qubit in the quantum register."""

@property
def sim_reg(self):
"""The register of the simulator."""
return self.ref.sim_reg
sim_reg: "QrackSimulator"
"""The register of the simulator."""

@property
def addr(self) -> int:
"""The global address of the qubit."""
return self.ref.addrs[self.pos]
state: QubitState
"""The state of the qubit (active/lost)"""

def is_active(self) -> bool:
"""Check if the qubit is active.
Expand All @@ -104,8 +65,8 @@ def is_active(self) -> bool:
True if the qubit is active, False otherwise.

"""
return self.ref.qubit_state[self.pos] is QubitState.Active
return self.state is QubitState.Active

def drop(self):
"""Drop the qubit in-place."""
self.ref.drop(self.pos)
self.state = QubitState.Lost
9 changes: 3 additions & 6 deletions src/bloqade/qasm2/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from kirin import types
from kirin.dialects import ilist

from bloqade.types import Qubit as Qubit, QubitType as QubitType

Expand All @@ -15,11 +16,7 @@ class Bit:
pass


class QReg:
"""Runtime representation of a quantum register."""

def __getitem__(self, index) -> Qubit:
raise NotImplementedError("cannot call __getitem__ outside of a kernel")
QReg = ilist.IList[Qubit, types.Any]


class CReg:
Expand All @@ -32,7 +29,7 @@ def __getitem__(self, index) -> Bit:
BitType = types.PyClass(Bit)
"""Kirin type for a classical bit."""

QRegType = types.PyClass(QReg)
QRegType = ilist.IListType[QubitType, types.Any]
"""Kirin type for a quantum register."""

CRegType = types.PyClass(CReg)
Expand Down
10 changes: 6 additions & 4 deletions test/pyqrack/runtime/noise/native/test_loss.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Literal
from unittest.mock import Mock

from kirin import ir
from kirin.dialects import ilist

from bloqade import qasm2
from bloqade.noise import native
from bloqade.pyqrack import PyQrackInterpreter, reg
from bloqade.pyqrack import PyQrackQubit, PyQrackInterpreter, reg
from bloqade.pyqrack.base import MockMemory

simulation = qasm2.extended.add(native)
Expand Down Expand Up @@ -34,12 +36,12 @@ def test_atom_loss(c: qasm2.CReg):
input = reg.CRegister(1)
memory = MockMemory()

result: reg.PyQrackReg = (
result: ilist.IList[PyQrackQubit, Literal[2]] = (
PyQrackInterpreter(simulation, memory=memory, rng_state=rng_state)
.run(test_atom_loss, (input,))
.expect()
)

assert result.qubit_state[0] is reg.QubitState.Lost
assert result.qubit_state[1] is reg.QubitState.Active
assert result[0].state is reg.QubitState.Lost
assert result[1].state is reg.QubitState.Active
assert input[0] is reg.Measurement.One
34 changes: 27 additions & 7 deletions test/pyqrack/test_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import pytest
from kirin import ir
from kirin.dialects import ilist

from bloqade import qasm2
from bloqade.pyqrack import PyQrack, reg
from bloqade.pyqrack import PyQrack, PyQrackQubit, reg


def test_target():
Expand All @@ -23,9 +24,10 @@ def ghz():

q = target.run(ghz)

assert isinstance(q, reg.PyQrackReg)
assert isinstance(q, ilist.IList)
assert isinstance(qubit := q[0], PyQrackQubit)

out = q.sim_reg.out_ket()
out = qubit.sim_reg.out_ket()

norm = math.sqrt(sum(abs(ele) ** 2 for ele in out))
phase = out[0] / abs(out[0])
Expand Down Expand Up @@ -56,9 +58,10 @@ def global_h():
target = PyQrack(3)
q = target.run(global_h)

assert isinstance(q, reg.PyQrackReg)
assert isinstance(q, ilist.IList)
assert isinstance(qubit := q[0], PyQrackQubit)

out = q.sim_reg.out_ket()
out = qubit.sim_reg.out_ket()

# remove global phase introduced by pyqrack
phase = out[0] / abs(out[0])
Expand Down Expand Up @@ -103,9 +106,10 @@ def multiple_registers():
target = PyQrack(6)
q1 = target.run(multiple_registers)

assert isinstance(q1, reg.PyQrackReg)
assert isinstance(q1, ilist.IList)
assert isinstance(qubit := q1[0], PyQrackQubit)

out = q1.sim_reg.out_ket()
out = qubit.sim_reg.out_ket()

assert out[0] == 1
for i in range(1, len(out)):
Expand Down Expand Up @@ -148,3 +152,19 @@ def measurement_that_errors():
q = qasm2.qreg(1)
c = qasm2.creg(1)
qasm2.measure(q[0], c)


def test_qreg_parallel():
# test for #161
@qasm2.extended
def parallel():
qreg = qasm2.qreg(4)
creg = qasm2.creg(4)
qasm2.parallel.u(qreg, theta=math.pi, phi=0.0, lam=0.0)
qasm2.measure(qreg, creg)
return creg

target = PyQrack(4)
result = target.run(parallel)

assert result == [reg.Measurement.One] * 4