Skip to content

Commit e1b4e14

Browse files
committed
Replace QReg by IList[Qubit]
1 parent 2e9518c commit e1b4e14

File tree

8 files changed

+68
-89
lines changed

8 files changed

+68
-89
lines changed

src/bloqade/pyqrack/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from .reg import (
22
CBitRef as CBitRef,
33
CRegister as CRegister,
4-
PyQrackReg as PyQrackReg,
54
QubitState as QubitState,
65
Measurement as Measurement,
76
PyQrackQubit as PyQrackQubit,

src/bloqade/pyqrack/noise/native.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def atom_loss_channel(
9393

9494
for qarg in active_qubits:
9595
if interp.rng_state.uniform() <= stmt.prob:
96-
qarg.ref.sim_reg.m(qarg.addr)
96+
qarg.sim_reg.m(qarg.addr)
9797
qarg.drop()
9898

9999
return ()

src/bloqade/pyqrack/qasm2/core.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from typing import Any
2+
13
from kirin import interp
4+
from kirin.dialects import ilist
25

36
from bloqade.pyqrack.reg import (
47
CBitRef,
58
CRegister,
6-
PyQrackReg,
79
QubitState,
810
Measurement,
911
PyQrackQubit,
@@ -19,14 +21,13 @@ def qreg_new(
1921
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.QRegNew
2022
):
2123
n_qubits: int = frame.get(stmt.n_qubits)
22-
return (
23-
PyQrackReg(
24-
size=n_qubits,
25-
sim_reg=interp.memory.sim_reg,
26-
addrs=interp.memory.allocate(n_qubits),
27-
qubit_state=[QubitState.Active] * n_qubits,
28-
),
24+
qreg = ilist.IList(
25+
[
26+
PyQrackQubit(i, interp.memory.sim_reg, QubitState.Active)
27+
for i in interp.memory.allocate(n_qubits=n_qubits)
28+
]
2929
)
30+
return (qreg,)
3031

3132
@interp.impl(core.CRegNew)
3233
def creg_new(
@@ -39,7 +40,9 @@ def creg_new(
3940
def qreg_get(
4041
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.QRegGet
4142
):
42-
return (PyQrackQubit(ref=frame.get(stmt.reg), pos=frame.get(stmt.idx)),)
43+
reg = frame.get(stmt.reg)
44+
i = frame.get(stmt.idx)
45+
return (reg[i],)
4346

4447
@interp.impl(core.CRegGet)
4548
def creg_get(
@@ -51,24 +54,19 @@ def creg_get(
5154
def measure(
5255
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.Measure
5356
):
54-
qarg: PyQrackQubit | PyQrackReg = frame.get(stmt.qarg)
57+
qarg: PyQrackQubit | ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qarg)
5558
carg: CBitRef | CRegister = frame.get(stmt.carg)
5659

5760
if isinstance(qarg, PyQrackQubit) and isinstance(carg, CBitRef):
5861
if qarg.is_active():
5962
carg.set_value(Measurement(qarg.sim_reg.m(qarg.addr)))
6063
else:
6164
carg.set_value(interp.loss_m_result)
62-
elif isinstance(qarg, PyQrackReg) and isinstance(carg, CRegister):
63-
# TODO: clean up iteration after PyQrackReg is refactored
64-
for i in range(qarg.size):
65-
qubit = qarg[i]
66-
67-
# TODO: make this consistent with PyQrackReg __getitem__ ?
65+
elif isinstance(qarg, ilist.IList) and isinstance(carg, CRegister):
66+
for i, qubit in enumerate(qarg):
6867
cbit = CBitRef(carg, i)
69-
7068
if qubit.is_active():
71-
cbit.set_value(Measurement(qarg.sim_reg.m(qubit.addr)))
69+
cbit.set_value(Measurement(qubit.sim_reg.m(qubit.addr)))
7270
else:
7371
cbit.set_value(interp.loss_m_result)
7472
else:

src/bloqade/pyqrack/qasm2/glob.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from kirin import interp
44
from kirin.dialects import ilist
55

6-
from bloqade.pyqrack.reg import PyQrackReg
6+
from bloqade.pyqrack.reg import PyQrackQubit
77
from bloqade.pyqrack.base import PyQrackInterpreter
88
from bloqade.qasm2.dialects import glob
99

@@ -12,7 +12,9 @@
1212
class PyQrackMethods(interp.MethodTable):
1313
@interp.impl(glob.UGate)
1414
def ugate(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: glob.UGate):
15-
registers: ilist.IList[PyQrackReg, Any] = frame.get(stmt.registers)
15+
registers: ilist.IList[ilist.IList[PyQrackQubit, Any], Any] = frame.get(
16+
stmt.registers
17+
)
1618
theta, phi, lam = (
1719
frame.get(stmt.theta),
1820
frame.get(stmt.phi),

src/bloqade/pyqrack/reg.py

Lines changed: 11 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import enum
2-
from typing import TYPE_CHECKING, List
2+
from typing import TYPE_CHECKING
33
from dataclasses import dataclass
44

5-
from bloqade.qasm2.types import QReg, Qubit
5+
from bloqade.qasm2.types import Qubit
66

77
if TYPE_CHECKING:
88
from pyqrack import QrackSimulator
@@ -45,57 +45,18 @@ class QubitState(enum.Enum):
4545
Lost = enum.auto()
4646

4747

48-
@dataclass(frozen=True)
49-
class PyQrackReg(QReg): # TODO: clean up implementation with list base class
50-
"""Simulation runtime value of a quantum register."""
51-
52-
size: int
53-
"""The number of qubits in this register."""
54-
55-
sim_reg: "QrackSimulator"
56-
"""The register of the simulator."""
57-
58-
addrs: tuple[int, ...]
59-
"""The global addresses of the qubits in this register."""
60-
61-
qubit_state: List[QubitState]
62-
"""The state of each qubit in this register."""
63-
64-
def drop(self, pos: int):
65-
"""Drop the qubit at the given position in-place.
66-
67-
Args
68-
pos (int): The position of the qubit to drop.
69-
70-
"""
71-
assert self.qubit_state[pos] is QubitState.Active, "Qubit already lost"
72-
self.qubit_state[pos] = QubitState.Lost
73-
74-
def __getitem__(self, pos: int):
75-
if not 0 <= pos < self.size:
76-
raise IndexError("Qubit index out of bounds of register.")
77-
return PyQrackQubit(self, pos)
78-
79-
80-
@dataclass(frozen=True)
48+
@dataclass
8149
class PyQrackQubit(Qubit):
8250
"""The runtime representation of a qubit reference."""
8351

84-
ref: PyQrackReg
85-
"""The quantum register that is holding this qubit."""
52+
addr: int
53+
"""The address of this qubit in the quantum register."""
8654

87-
pos: int
88-
"""The position of this qubit in the quantum register."""
89-
90-
@property
91-
def sim_reg(self):
92-
"""The register of the simulator."""
93-
return self.ref.sim_reg
55+
sim_reg: "QrackSimulator"
56+
"""The register of the simulator."""
9457

95-
@property
96-
def addr(self) -> int:
97-
"""The global address of the qubit."""
98-
return self.ref.addrs[self.pos]
58+
state: QubitState
59+
"""The state of the qubit (active/lost)"""
9960

10061
def is_active(self) -> bool:
10162
"""Check if the qubit is active.
@@ -104,8 +65,8 @@ def is_active(self) -> bool:
10465
True if the qubit is active, False otherwise.
10566
10667
"""
107-
return self.ref.qubit_state[self.pos] is QubitState.Active
68+
return self.state is QubitState.Active
10869

10970
def drop(self):
11071
"""Drop the qubit in-place."""
111-
self.ref.drop(self.pos)
72+
self.state = QubitState.Lost

src/bloqade/qasm2/types.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from kirin import types
2+
from kirin.dialects import ilist
23

34
from bloqade.types import Qubit as Qubit, QubitType as QubitType
45

@@ -15,11 +16,7 @@ class Bit:
1516
pass
1617

1718

18-
class QReg:
19-
"""Runtime representation of a quantum register."""
20-
21-
def __getitem__(self, index) -> Qubit:
22-
raise NotImplementedError("cannot call __getitem__ outside of a kernel")
19+
QReg = ilist.IList[Qubit, types.Any]
2320

2421

2522
class CReg:
@@ -32,7 +29,7 @@ def __getitem__(self, index) -> Bit:
3229
BitType = types.PyClass(Bit)
3330
"""Kirin type for a classical bit."""
3431

35-
QRegType = types.PyClass(QReg)
32+
QRegType = ilist.IListType[QubitType, types.Any]
3633
"""Kirin type for a quantum register."""
3734

3835
CRegType = types.PyClass(CReg)

test/pyqrack/runtime/noise/native/test_loss.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from typing import Any
12
from unittest.mock import Mock
23

34
from kirin import ir
5+
from kirin.dialects import ilist
46

57
from bloqade import qasm2
68
from bloqade.noise import native
7-
from bloqade.pyqrack import PyQrackInterpreter, reg
9+
from bloqade.pyqrack import PyQrackQubit, PyQrackInterpreter, reg
810
from bloqade.pyqrack.base import MockMemory
911

1012
simulation = qasm2.extended.add(native)
@@ -34,12 +36,12 @@ def test_atom_loss(c: qasm2.CReg):
3436
input = reg.CRegister(1)
3537
memory = MockMemory()
3638

37-
result: reg.PyQrackReg = (
39+
result: ilist.IList[PyQrackQubit, Any] = (
3840
PyQrackInterpreter(simulation, memory=memory, rng_state=rng_state)
3941
.run(test_atom_loss, (input,))
4042
.expect()
4143
)
4244

43-
assert result.qubit_state[0] is reg.QubitState.Lost
44-
assert result.qubit_state[1] is reg.QubitState.Active
45+
assert result[0].state is reg.QubitState.Lost
46+
assert result[1].state is reg.QubitState.Active
4547
assert input[0] is reg.Measurement.One

test/pyqrack/test_target.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import pytest
44
from kirin import ir
5+
from kirin.dialects import ilist
56

67
from bloqade import qasm2
7-
from bloqade.pyqrack import PyQrack, reg
8+
from bloqade.pyqrack import PyQrack, PyQrackQubit, reg
89

910

1011
def test_target():
@@ -23,9 +24,10 @@ def ghz():
2324

2425
q = target.run(ghz)
2526

26-
assert isinstance(q, reg.PyQrackReg)
27+
assert isinstance(q, ilist.IList)
28+
assert isinstance(qubit := q[0], PyQrackQubit)
2729

28-
out = q.sim_reg.out_ket()
30+
out = qubit.sim_reg.out_ket()
2931

3032
norm = math.sqrt(sum(abs(ele) ** 2 for ele in out))
3133
phase = out[0] / abs(out[0])
@@ -56,9 +58,10 @@ def global_h():
5658
target = PyQrack(3)
5759
q = target.run(global_h)
5860

59-
assert isinstance(q, reg.PyQrackReg)
61+
assert isinstance(q, ilist.IList)
62+
assert isinstance(qubit := q[0], PyQrackQubit)
6063

61-
out = q.sim_reg.out_ket()
64+
out = qubit.sim_reg.out_ket()
6265

6366
# remove global phase introduced by pyqrack
6467
phase = out[0] / abs(out[0])
@@ -103,9 +106,10 @@ def multiple_registers():
103106
target = PyQrack(6)
104107
q1 = target.run(multiple_registers)
105108

106-
assert isinstance(q1, reg.PyQrackReg)
109+
assert isinstance(q1, ilist.IList)
110+
assert isinstance(qubit := q1[0], PyQrackQubit)
107111

108-
out = q1.sim_reg.out_ket()
112+
out = qubit.sim_reg.out_ket()
109113

110114
assert out[0] == 1
111115
for i in range(1, len(out)):
@@ -148,3 +152,19 @@ def measurement_that_errors():
148152
q = qasm2.qreg(1)
149153
c = qasm2.creg(1)
150154
qasm2.measure(q[0], c)
155+
156+
157+
def test_qreg_parallel():
158+
# test for #161
159+
@qasm2.extended
160+
def parallel():
161+
qreg = qasm2.qreg(4)
162+
creg = qasm2.creg(4)
163+
qasm2.parallel.u(qreg, theta=math.pi, phi=0.0, lam=0.0)
164+
qasm2.measure(qreg, creg)
165+
return creg
166+
167+
target = PyQrack(4)
168+
result = target.run(parallel)
169+
170+
assert result == [reg.Measurement.One] * 4

0 commit comments

Comments
 (0)