Skip to content

Commit 1a67a03

Browse files
committed
Merge branch 'main' into 19-rewrite-from-squin-to-stim
2 parents e45c10c + 5a81a6a commit 1a67a03

File tree

25 files changed

+256
-165
lines changed

25 files changed

+256
-165
lines changed

src/bloqade/analysis/address/impls.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,10 @@ def unwrap(
192192

193193
origin_qubit = frame.get(stmt.qubit)
194194

195-
return (AddressWire(origin_qubit=origin_qubit),)
195+
if isinstance(origin_qubit, AddressQubit):
196+
return (AddressWire(origin_qubit=origin_qubit),)
197+
else:
198+
return (Address.top(),)
196199

197200
@interp.impl(squin.wire.Apply)
198201
def apply(
@@ -201,14 +204,7 @@ def apply(
201204
frame: ForwardFrame[Address],
202205
stmt: squin.wire.Apply,
203206
):
204-
205-
origin_qubits = tuple(
206-
[frame.get(input_elem).origin_qubit for input_elem in stmt.inputs]
207-
)
208-
new_address_wires = tuple(
209-
[AddressWire(origin_qubit=origin_qubit) for origin_qubit in origin_qubits]
210-
)
211-
return new_address_wires
207+
return frame.get_values(stmt.inputs)
212208

213209
@interp.impl(squin.wire.MeasureAndReset)
214210
def measure_and_reset(

src/bloqade/analysis/address/lattice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,5 +81,5 @@ class AddressWire(Address):
8181

8282
def is_subseteq(self, other: Address) -> bool:
8383
if isinstance(other, AddressWire):
84-
return self.origin_qubit == self.origin_qubit
84+
return self.origin_qubit == other.origin_qubit
8585
return False

src/bloqade/noise/native/model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,9 @@ class MoveNoiseModelABC(abc.ABC):
102102
params: MoveNoiseParams = field(default_factory=MoveNoiseParams)
103103
"""Parameters for calculating move noise."""
104104

105-
@classmethod
106105
@abc.abstractmethod
107106
def parallel_cz_errors(
108-
cls, ctrls: List[int], qargs: List[int], rest: List[int]
107+
self, ctrls: List[int], qargs: List[int], rest: List[int]
109108
) -> Dict[Tuple[float, float, float, float], List[int]]:
110109
"""Takes a set of ctrls and qargs and returns a noise model for all qubits."""
111110
pass

src/bloqade/pyqrack/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from .reg import (
22
CBitRef as CBitRef,
33
CRegister as CRegister,
4-
PyQrackReg as PyQrackReg,
54
QubitState as QubitState,
65
Measurement as Measurement,
76
PyQrackQubit as PyQrackQubit,

src/bloqade/pyqrack/noise/native.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,7 @@ def atom_loss_channel(
9393

9494
for qarg in active_qubits:
9595
if interp.rng_state.uniform() <= stmt.prob:
96-
sim_reg = qarg.ref.sim_reg
97-
sim_reg.force_m(qarg.addr, 0)
96+
qarg.sim_reg.m(qarg.addr)
9897
qarg.drop()
9998

10099
return ()

src/bloqade/pyqrack/qasm2/core.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
from typing import Any
2+
13
from kirin import interp
4+
from kirin.interp import InterpreterError
5+
from kirin.dialects import ilist
26

37
from bloqade.pyqrack.reg import (
48
CBitRef,
59
CRegister,
6-
PyQrackReg,
710
QubitState,
811
Measurement,
912
PyQrackQubit,
@@ -19,14 +22,13 @@ def qreg_new(
1922
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.QRegNew
2023
):
2124
n_qubits: int = frame.get(stmt.n_qubits)
22-
return (
23-
PyQrackReg(
24-
size=n_qubits,
25-
sim_reg=interp.memory.sim_reg,
26-
addrs=interp.memory.allocate(n_qubits),
27-
qubit_state=[QubitState.Active] * n_qubits,
28-
),
25+
qreg = ilist.IList(
26+
[
27+
PyQrackQubit(i, interp.memory.sim_reg, QubitState.Active)
28+
for i in interp.memory.allocate(n_qubits=n_qubits)
29+
]
2930
)
31+
return (qreg,)
3032

3133
@interp.impl(core.CRegNew)
3234
def creg_new(
@@ -39,7 +41,9 @@ def creg_new(
3941
def qreg_get(
4042
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.QRegGet
4143
):
42-
return (PyQrackQubit(ref=frame.get(stmt.reg), pos=frame.get(stmt.idx)),)
44+
reg = frame.get(stmt.reg)
45+
i = frame.get(stmt.idx)
46+
return (reg[i],)
4347

4448
@interp.impl(core.CRegGet)
4549
def creg_get(
@@ -51,12 +55,25 @@ def creg_get(
5155
def measure(
5256
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.Measure
5357
):
54-
qarg: PyQrackQubit = frame.get(stmt.qarg)
55-
carg: CBitRef = frame.get(stmt.carg)
56-
if qarg.is_active():
57-
carg.set_value(Measurement(qarg.sim_reg.m(qarg.addr)))
58+
qarg: PyQrackQubit | ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qarg)
59+
carg: CBitRef | CRegister = frame.get(stmt.carg)
60+
61+
if isinstance(qarg, PyQrackQubit) and isinstance(carg, CBitRef):
62+
if qarg.is_active():
63+
carg.set_value(Measurement(qarg.sim_reg.m(qarg.addr)))
64+
else:
65+
carg.set_value(interp.loss_m_result)
66+
elif isinstance(qarg, ilist.IList) and isinstance(carg, CRegister):
67+
for i, qubit in enumerate(qarg):
68+
cbit = CBitRef(carg, i)
69+
if qubit.is_active():
70+
cbit.set_value(Measurement(qubit.sim_reg.m(qubit.addr)))
71+
else:
72+
cbit.set_value(interp.loss_m_result)
5873
else:
59-
carg.set_value(interp.loss_m_result)
74+
raise InterpreterError(
75+
f"Expected measure call on either a single qubit and classical bit, or two registers, but got the types {type(qarg)} and {type(carg)}"
76+
)
6077

6178
return ()
6279

src/bloqade/pyqrack/qasm2/glob.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from kirin import interp
44
from kirin.dialects import ilist
55

6-
from bloqade.pyqrack.reg import PyQrackReg
6+
from bloqade.pyqrack.reg import PyQrackQubit
77
from bloqade.pyqrack.base import PyQrackInterpreter
88
from bloqade.qasm2.dialects import glob
99

@@ -12,7 +12,9 @@
1212
class PyQrackMethods(interp.MethodTable):
1313
@interp.impl(glob.UGate)
1414
def ugate(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: glob.UGate):
15-
registers: ilist.IList[PyQrackReg, Any] = frame.get(stmt.registers)
15+
registers: ilist.IList[ilist.IList[PyQrackQubit, Any], Any] = frame.get(
16+
stmt.registers
17+
)
1618
theta, phi, lam = (
1719
frame.get(stmt.theta),
1820
frame.get(stmt.phi),

src/bloqade/pyqrack/reg.py

Lines changed: 11 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import enum
2-
from typing import TYPE_CHECKING, List
2+
from typing import TYPE_CHECKING
33
from dataclasses import dataclass
44

5-
from bloqade.qasm2.types import QReg, Qubit
5+
from bloqade.qasm2.types import Qubit
66

77
if TYPE_CHECKING:
88
from pyqrack import QrackSimulator
@@ -45,57 +45,18 @@ class QubitState(enum.Enum):
4545
Lost = enum.auto()
4646

4747

48-
@dataclass(frozen=True)
49-
class PyQrackReg(QReg): # TODO: clean up implementation with list base class
50-
"""Simulation runtime value of a quantum register."""
51-
52-
size: int
53-
"""The number of qubits in this register."""
54-
55-
sim_reg: "QrackSimulator"
56-
"""The register of the simulator."""
57-
58-
addrs: tuple[int, ...]
59-
"""The global addresses of the qubits in this register."""
60-
61-
qubit_state: List[QubitState]
62-
"""The state of each qubit in this register."""
63-
64-
def drop(self, pos: int):
65-
"""Drop the qubit at the given position in-place.
66-
67-
Args
68-
pos (int): The position of the qubit to drop.
69-
70-
"""
71-
assert self.qubit_state[pos] is QubitState.Active, "Qubit already lost"
72-
self.qubit_state[pos] = QubitState.Lost
73-
74-
def __getitem__(self, pos: int):
75-
if not 0 <= pos < self.size:
76-
raise IndexError("Qubit index out of bounds of register.")
77-
return PyQrackQubit(self, pos)
78-
79-
80-
@dataclass(frozen=True)
48+
@dataclass
8149
class PyQrackQubit(Qubit):
8250
"""The runtime representation of a qubit reference."""
8351

84-
ref: PyQrackReg
85-
"""The quantum register that is holding this qubit."""
52+
addr: int
53+
"""The address of this qubit in the quantum register."""
8654

87-
pos: int
88-
"""The position of this qubit in the quantum register."""
89-
90-
@property
91-
def sim_reg(self):
92-
"""The register of the simulator."""
93-
return self.ref.sim_reg
55+
sim_reg: "QrackSimulator"
56+
"""The register of the simulator."""
9457

95-
@property
96-
def addr(self) -> int:
97-
"""The global address of the qubit."""
98-
return self.ref.addrs[self.pos]
58+
state: QubitState
59+
"""The state of the qubit (active/lost)"""
9960

10061
def is_active(self) -> bool:
10162
"""Check if the qubit is active.
@@ -104,8 +65,8 @@ def is_active(self) -> bool:
10465
True if the qubit is active, False otherwise.
10566
10667
"""
107-
return self.ref.qubit_state[self.pos] is QubitState.Active
68+
return self.state is QubitState.Active
10869

10970
def drop(self):
11071
"""Drop the qubit in-place."""
111-
self.ref.drop(self.pos)
72+
self.state = QubitState.Lost

src/bloqade/qasm2/_wrappers.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import overload
2+
13
from kirin.lowering import wraps
24

35
from .types import Bit, CReg, QReg, Qubit
@@ -58,7 +60,7 @@ def reset(qarg: Qubit) -> None:
5860
...
5961

6062

61-
@wraps(core.Measure)
63+
@overload
6264
def measure(qarg: Qubit, cbit: Bit) -> None:
6365
"""
6466
Measure the qubit `qarg` and store the result in the classical bit `cbit`.
@@ -70,6 +72,22 @@ def measure(qarg: Qubit, cbit: Bit) -> None:
7072
...
7173

7274

75+
@overload
76+
def measure(qarg: QReg, carg: CReg) -> None:
77+
"""
78+
Measure each qubit in the quantum register `qarg` and store the result in the classical register `carg`.
79+
80+
Args:
81+
qarg: The quantum register to measure.
82+
carg: The classical bit to store the result in.
83+
"""
84+
...
85+
86+
87+
@wraps(core.Measure)
88+
def measure(qarg, carg) -> None: ...
89+
90+
7391
@wraps(uop.CX)
7492
def cx(ctrl: Qubit, qarg: Qubit) -> None:
7593
"""

src/bloqade/qasm2/dialects/core/stmts.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,21 @@ class Measure(ir.Statement):
4646

4747
name = "measure"
4848
traits = frozenset({lowering.FromPythonCall()})
49-
qarg: ir.SSAValue = info.argument(QubitType)
50-
"""qarg (Qubit): The qubit to measure."""
51-
carg: ir.SSAValue = info.argument(BitType)
52-
"""carg (Bit): The bit to store the result in."""
49+
qarg: ir.SSAValue = info.argument(QubitType | QRegType)
50+
"""qarg (Qubit | QReg): The qubit or quantum register to measure."""
51+
carg: ir.SSAValue = info.argument(BitType | CRegType)
52+
"""carg (Bit | CReg): The bit or register to store the result in."""
53+
54+
def check_type(self) -> None:
55+
qarg_is_qubit = self.qarg.type.is_subseteq(QubitType)
56+
carg_is_bit = self.carg.type.is_subseteq(BitType)
57+
if (qarg_is_qubit and not carg_is_bit) or (not qarg_is_qubit and carg_is_bit):
58+
raise ir.TypeCheckError(
59+
self,
60+
"Can't perform measurement with single (qu)bit and an entire register!",
61+
help="Instead of `measure(qreg[i], creg)` or `measure(qreg, creg[i])`"
62+
"use `measure(qreg[i], creg[j])` or `measure(qreg, creg)`, respectively.",
63+
)
5364

5465

5566
@statement(dialect=dialect)

0 commit comments

Comments
 (0)