Skip to content

Commit 5a81a6a

Browse files
authored
Replace QReg by IList[Qubit] (#181)
* Replace QReg by IList[Qubit] * Size of IList in test * Raise an InterpreterError instead of RuntimeError in measure impl
1 parent 2e9518c commit 5a81a6a

File tree

8 files changed

+70
-90
lines changed

8 files changed

+70
-90
lines changed

src/bloqade/pyqrack/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from .reg import (
22
CBitRef as CBitRef,
33
CRegister as CRegister,
4-
PyQrackReg as PyQrackReg,
54
QubitState as QubitState,
65
Measurement as Measurement,
76
PyQrackQubit as PyQrackQubit,

src/bloqade/pyqrack/noise/native.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def atom_loss_channel(
9393

9494
for qarg in active_qubits:
9595
if interp.rng_state.uniform() <= stmt.prob:
96-
qarg.ref.sim_reg.m(qarg.addr)
96+
qarg.sim_reg.m(qarg.addr)
9797
qarg.drop()
9898

9999
return ()

src/bloqade/pyqrack/qasm2/core.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
from typing import Any
2+
13
from kirin import interp
4+
from kirin.interp import InterpreterError
5+
from kirin.dialects import ilist
26

37
from bloqade.pyqrack.reg import (
48
CBitRef,
59
CRegister,
6-
PyQrackReg,
710
QubitState,
811
Measurement,
912
PyQrackQubit,
@@ -19,14 +22,13 @@ def qreg_new(
1922
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.QRegNew
2023
):
2124
n_qubits: int = frame.get(stmt.n_qubits)
22-
return (
23-
PyQrackReg(
24-
size=n_qubits,
25-
sim_reg=interp.memory.sim_reg,
26-
addrs=interp.memory.allocate(n_qubits),
27-
qubit_state=[QubitState.Active] * n_qubits,
28-
),
25+
qreg = ilist.IList(
26+
[
27+
PyQrackQubit(i, interp.memory.sim_reg, QubitState.Active)
28+
for i in interp.memory.allocate(n_qubits=n_qubits)
29+
]
2930
)
31+
return (qreg,)
3032

3133
@interp.impl(core.CRegNew)
3234
def creg_new(
@@ -39,7 +41,9 @@ def creg_new(
3941
def qreg_get(
4042
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.QRegGet
4143
):
42-
return (PyQrackQubit(ref=frame.get(stmt.reg), pos=frame.get(stmt.idx)),)
44+
reg = frame.get(stmt.reg)
45+
i = frame.get(stmt.idx)
46+
return (reg[i],)
4347

4448
@interp.impl(core.CRegGet)
4549
def creg_get(
@@ -51,28 +55,23 @@ def creg_get(
5155
def measure(
5256
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.Measure
5357
):
54-
qarg: PyQrackQubit | PyQrackReg = frame.get(stmt.qarg)
58+
qarg: PyQrackQubit | ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qarg)
5559
carg: CBitRef | CRegister = frame.get(stmt.carg)
5660

5761
if isinstance(qarg, PyQrackQubit) and isinstance(carg, CBitRef):
5862
if qarg.is_active():
5963
carg.set_value(Measurement(qarg.sim_reg.m(qarg.addr)))
6064
else:
6165
carg.set_value(interp.loss_m_result)
62-
elif isinstance(qarg, PyQrackReg) and isinstance(carg, CRegister):
63-
# TODO: clean up iteration after PyQrackReg is refactored
64-
for i in range(qarg.size):
65-
qubit = qarg[i]
66-
67-
# TODO: make this consistent with PyQrackReg __getitem__ ?
66+
elif isinstance(qarg, ilist.IList) and isinstance(carg, CRegister):
67+
for i, qubit in enumerate(qarg):
6868
cbit = CBitRef(carg, i)
69-
7069
if qubit.is_active():
71-
cbit.set_value(Measurement(qarg.sim_reg.m(qubit.addr)))
70+
cbit.set_value(Measurement(qubit.sim_reg.m(qubit.addr)))
7271
else:
7372
cbit.set_value(interp.loss_m_result)
7473
else:
75-
raise RuntimeError(
74+
raise InterpreterError(
7675
f"Expected measure call on either a single qubit and classical bit, or two registers, but got the types {type(qarg)} and {type(carg)}"
7776
)
7877

src/bloqade/pyqrack/qasm2/glob.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from kirin import interp
44
from kirin.dialects import ilist
55

6-
from bloqade.pyqrack.reg import PyQrackReg
6+
from bloqade.pyqrack.reg import PyQrackQubit
77
from bloqade.pyqrack.base import PyQrackInterpreter
88
from bloqade.qasm2.dialects import glob
99

@@ -12,7 +12,9 @@
1212
class PyQrackMethods(interp.MethodTable):
1313
@interp.impl(glob.UGate)
1414
def ugate(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: glob.UGate):
15-
registers: ilist.IList[PyQrackReg, Any] = frame.get(stmt.registers)
15+
registers: ilist.IList[ilist.IList[PyQrackQubit, Any], Any] = frame.get(
16+
stmt.registers
17+
)
1618
theta, phi, lam = (
1719
frame.get(stmt.theta),
1820
frame.get(stmt.phi),

src/bloqade/pyqrack/reg.py

Lines changed: 11 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import enum
2-
from typing import TYPE_CHECKING, List
2+
from typing import TYPE_CHECKING
33
from dataclasses import dataclass
44

5-
from bloqade.qasm2.types import QReg, Qubit
5+
from bloqade.qasm2.types import Qubit
66

77
if TYPE_CHECKING:
88
from pyqrack import QrackSimulator
@@ -45,57 +45,18 @@ class QubitState(enum.Enum):
4545
Lost = enum.auto()
4646

4747

48-
@dataclass(frozen=True)
49-
class PyQrackReg(QReg): # TODO: clean up implementation with list base class
50-
"""Simulation runtime value of a quantum register."""
51-
52-
size: int
53-
"""The number of qubits in this register."""
54-
55-
sim_reg: "QrackSimulator"
56-
"""The register of the simulator."""
57-
58-
addrs: tuple[int, ...]
59-
"""The global addresses of the qubits in this register."""
60-
61-
qubit_state: List[QubitState]
62-
"""The state of each qubit in this register."""
63-
64-
def drop(self, pos: int):
65-
"""Drop the qubit at the given position in-place.
66-
67-
Args
68-
pos (int): The position of the qubit to drop.
69-
70-
"""
71-
assert self.qubit_state[pos] is QubitState.Active, "Qubit already lost"
72-
self.qubit_state[pos] = QubitState.Lost
73-
74-
def __getitem__(self, pos: int):
75-
if not 0 <= pos < self.size:
76-
raise IndexError("Qubit index out of bounds of register.")
77-
return PyQrackQubit(self, pos)
78-
79-
80-
@dataclass(frozen=True)
48+
@dataclass
8149
class PyQrackQubit(Qubit):
8250
"""The runtime representation of a qubit reference."""
8351

84-
ref: PyQrackReg
85-
"""The quantum register that is holding this qubit."""
52+
addr: int
53+
"""The address of this qubit in the quantum register."""
8654

87-
pos: int
88-
"""The position of this qubit in the quantum register."""
89-
90-
@property
91-
def sim_reg(self):
92-
"""The register of the simulator."""
93-
return self.ref.sim_reg
55+
sim_reg: "QrackSimulator"
56+
"""The register of the simulator."""
9457

95-
@property
96-
def addr(self) -> int:
97-
"""The global address of the qubit."""
98-
return self.ref.addrs[self.pos]
58+
state: QubitState
59+
"""The state of the qubit (active/lost)"""
9960

10061
def is_active(self) -> bool:
10162
"""Check if the qubit is active.
@@ -104,8 +65,8 @@ def is_active(self) -> bool:
10465
True if the qubit is active, False otherwise.
10566
10667
"""
107-
return self.ref.qubit_state[self.pos] is QubitState.Active
68+
return self.state is QubitState.Active
10869

10970
def drop(self):
11071
"""Drop the qubit in-place."""
111-
self.ref.drop(self.pos)
72+
self.state = QubitState.Lost

src/bloqade/qasm2/types.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from kirin import types
2+
from kirin.dialects import ilist
23

34
from bloqade.types import Qubit as Qubit, QubitType as QubitType
45

@@ -15,11 +16,7 @@ class Bit:
1516
pass
1617

1718

18-
class QReg:
19-
"""Runtime representation of a quantum register."""
20-
21-
def __getitem__(self, index) -> Qubit:
22-
raise NotImplementedError("cannot call __getitem__ outside of a kernel")
19+
QReg = ilist.IList[Qubit, types.Any]
2320

2421

2522
class CReg:
@@ -32,7 +29,7 @@ def __getitem__(self, index) -> Bit:
3229
BitType = types.PyClass(Bit)
3330
"""Kirin type for a classical bit."""
3431

35-
QRegType = types.PyClass(QReg)
32+
QRegType = ilist.IListType[QubitType, types.Any]
3633
"""Kirin type for a quantum register."""
3734

3835
CRegType = types.PyClass(CReg)

test/pyqrack/runtime/noise/native/test_loss.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from typing import Literal
12
from unittest.mock import Mock
23

34
from kirin import ir
5+
from kirin.dialects import ilist
46

57
from bloqade import qasm2
68
from bloqade.noise import native
7-
from bloqade.pyqrack import PyQrackInterpreter, reg
9+
from bloqade.pyqrack import PyQrackQubit, PyQrackInterpreter, reg
810
from bloqade.pyqrack.base import MockMemory
911

1012
simulation = qasm2.extended.add(native)
@@ -34,12 +36,12 @@ def test_atom_loss(c: qasm2.CReg):
3436
input = reg.CRegister(1)
3537
memory = MockMemory()
3638

37-
result: reg.PyQrackReg = (
39+
result: ilist.IList[PyQrackQubit, Literal[2]] = (
3840
PyQrackInterpreter(simulation, memory=memory, rng_state=rng_state)
3941
.run(test_atom_loss, (input,))
4042
.expect()
4143
)
4244

43-
assert result.qubit_state[0] is reg.QubitState.Lost
44-
assert result.qubit_state[1] is reg.QubitState.Active
45+
assert result[0].state is reg.QubitState.Lost
46+
assert result[1].state is reg.QubitState.Active
4547
assert input[0] is reg.Measurement.One

test/pyqrack/test_target.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import pytest
44
from kirin import ir
5+
from kirin.dialects import ilist
56

67
from bloqade import qasm2
7-
from bloqade.pyqrack import PyQrack, reg
8+
from bloqade.pyqrack import PyQrack, PyQrackQubit, reg
89

910

1011
def test_target():
@@ -23,9 +24,10 @@ def ghz():
2324

2425
q = target.run(ghz)
2526

26-
assert isinstance(q, reg.PyQrackReg)
27+
assert isinstance(q, ilist.IList)
28+
assert isinstance(qubit := q[0], PyQrackQubit)
2729

28-
out = q.sim_reg.out_ket()
30+
out = qubit.sim_reg.out_ket()
2931

3032
norm = math.sqrt(sum(abs(ele) ** 2 for ele in out))
3133
phase = out[0] / abs(out[0])
@@ -56,9 +58,10 @@ def global_h():
5658
target = PyQrack(3)
5759
q = target.run(global_h)
5860

59-
assert isinstance(q, reg.PyQrackReg)
61+
assert isinstance(q, ilist.IList)
62+
assert isinstance(qubit := q[0], PyQrackQubit)
6063

61-
out = q.sim_reg.out_ket()
64+
out = qubit.sim_reg.out_ket()
6265

6366
# remove global phase introduced by pyqrack
6467
phase = out[0] / abs(out[0])
@@ -103,9 +106,10 @@ def multiple_registers():
103106
target = PyQrack(6)
104107
q1 = target.run(multiple_registers)
105108

106-
assert isinstance(q1, reg.PyQrackReg)
109+
assert isinstance(q1, ilist.IList)
110+
assert isinstance(qubit := q1[0], PyQrackQubit)
107111

108-
out = q1.sim_reg.out_ket()
112+
out = qubit.sim_reg.out_ket()
109113

110114
assert out[0] == 1
111115
for i in range(1, len(out)):
@@ -148,3 +152,19 @@ def measurement_that_errors():
148152
q = qasm2.qreg(1)
149153
c = qasm2.creg(1)
150154
qasm2.measure(q[0], c)
155+
156+
157+
def test_qreg_parallel():
158+
# test for #161
159+
@qasm2.extended
160+
def parallel():
161+
qreg = qasm2.qreg(4)
162+
creg = qasm2.creg(4)
163+
qasm2.parallel.u(qreg, theta=math.pi, phi=0.0, lam=0.0)
164+
qasm2.measure(qreg, creg)
165+
return creg
166+
167+
target = PyQrack(4)
168+
result = target.run(parallel)
169+
170+
assert result == [reg.Measurement.One] * 4

0 commit comments

Comments
 (0)