Skip to content

Commit 009a671

Browse files
committed
Add qubit ID and measurement ID statements (#499)
Implements #496 . I had to adapt the returns from the measurement PyQrack implementation to actually return a `Measurement` rather than just a boolean, in order to be able to store an ID on the resulting measurement. Not sure if the implementation is that clean here.
1 parent 11cf656 commit 009a671

File tree

6 files changed

+106
-7
lines changed

6 files changed

+106
-7
lines changed

src/bloqade/pyqrack/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,13 @@ class PyQrackInterpreter(Interpreter, typing.Generic[MemoryType]):
146146
loss_m_result: Measurement = field(default=Measurement.One, kw_only=True)
147147
"""The value of a measurement result when a qubit is lost."""
148148

149+
global_measurement_id: int = field(init=False, default=0)
150+
149151
def initialize(self) -> Self:
150152
super().initialize()
151153
self.memory.reset() # reset allocated qubits
152154
return self
155+
156+
def set_global_measurement_id(self, m: Measurement):
157+
m.measurement_id = self.global_measurement_id
158+
self.global_measurement_id += 1

src/bloqade/pyqrack/reg.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,20 @@
22
from typing import TYPE_CHECKING
33
from dataclasses import dataclass
44

5+
from bloqade.types import MeasurementResult
56
from bloqade.qasm2.types import Qubit
67

78
if TYPE_CHECKING:
89
from pyqrack import QrackSimulator
910

1011

11-
class Measurement(enum.IntEnum):
12+
class Measurement(MeasurementResult, enum.IntEnum):
1213
"""Enumeration of measurement results."""
1314

15+
def __init__(self, measurement_id: int = 0) -> None:
16+
super().__init__()
17+
self.measurement_id = measurement_id
18+
1419
Zero = 0
1520
One = 1
1621
Lost = enum.auto()

src/bloqade/pyqrack/squin/noise/native.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def n_sites(self) -> int:
4646
return 1
4747

4848
def apply(self, qubit: PyQrackQubit, adjoint: bool = False) -> None:
49-
if random.uniform(0.0, 1.0) < self.p:
49+
if random.uniform(0.0, 1.0) <= self.p:
5050
qubit.state = QubitState.Lost
5151

5252

src/bloqade/pyqrack/squin/qubit.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from kirin.dialects import ilist
55

66
from bloqade.squin import qubit
7-
from bloqade.pyqrack.reg import QubitState, PyQrackQubit
7+
from bloqade.pyqrack.reg import QubitState, Measurement, PyQrackQubit
88
from bloqade.pyqrack.base import PyQrackInterpreter
99

1010
from .runtime import OperatorRuntimeABC
@@ -41,9 +41,12 @@ def broadcast(
4141

4242
def _measure_qubit(self, qbit: PyQrackQubit, interp: PyQrackInterpreter):
4343
if qbit.is_active():
44-
return bool(qbit.sim_reg.m(qbit.addr))
44+
m = Measurement(bool(qbit.sim_reg.m(qbit.addr)))
4545
else:
46-
return interp.loss_m_result
46+
m = Measurement(interp.loss_m_result)
47+
48+
interp.set_global_measurement_id(m)
49+
return m
4750

4851
@interp.impl(qubit.MeasureQubitList)
4952
def measure_qubit_list(
@@ -63,3 +66,17 @@ def measure_qubit(
6366
qbit: PyQrackQubit = frame.get(stmt.qubit)
6467
result = self._measure_qubit(qbit, interp)
6568
return (result,)
69+
70+
@interp.impl(qubit.QubitId)
71+
def qubit_id(
72+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.QubitId
73+
):
74+
qbit: PyQrackQubit = frame.get(stmt.qubit)
75+
return (qbit.addr,)
76+
77+
@interp.impl(qubit.MeasurementId)
78+
def measurement_id(
79+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasurementId
80+
):
81+
measurement: Measurement = frame.get(stmt.measurement)
82+
return (measurement.measurement_id,)

src/bloqade/squin/qubit.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,20 @@ class MeasureQubitList(ir.Statement):
7979
result: ir.ResultValue = info.result(ilist.IListType[MeasurementResultType])
8080

8181

82+
@statement(dialect=dialect)
83+
class QubitId(ir.Statement):
84+
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
85+
qubit: ir.SSAValue = info.argument(QubitType)
86+
result: ir.ResultValue = info.result(types.Int)
87+
88+
89+
@statement(dialect=dialect)
90+
class MeasurementId(ir.Statement):
91+
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
92+
measurement: ir.SSAValue = info.argument(MeasurementResultType)
93+
result: ir.ResultValue = info.result(types.Int)
94+
95+
8296
# NOTE: no dependent types in Python, so we have to mark it Any...
8397
@wraps(New)
8498
def new(n_qubits: int) -> ilist.IList[Qubit, Any]:
@@ -127,9 +141,10 @@ def measure(input: Any) -> Any:
127141
input: A qubit or a list of qubits to measure.
128142
129143
Returns:
130-
bool | list[bool]: The result of the measurement. If a single qubit is measured,
131-
a single boolean is returned. If a list of qubits is measured, a list of booleans
144+
MeasurementResult | list[MeasurementResult]: The result of the measurement. If a single qubit is measured,
145+
a single result is returned. If a list of qubits is measured, a list of results
132146
is returned.
147+
A MeasurementResult can represent both 0 and 1, but also atoms that are lost.
133148
"""
134149
...
135150

@@ -169,3 +184,11 @@ def ghz():
169184
None
170185
"""
171186
...
187+
188+
189+
@wraps(QubitId)
190+
def get_qubit_id(qubit: Qubit) -> int: ...
191+
192+
193+
@wraps(MeasurementId)
194+
def get_measurement_id(measurement: MeasurementResult) -> int: ...

test/squin/test_qubit.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from kirin import types
2+
3+
from bloqade import squin
4+
from bloqade.pyqrack import StackMemorySimulator
5+
from bloqade.pyqrack.reg import Measurement
6+
7+
8+
def test_get_ids():
9+
@squin.kernel
10+
def main():
11+
q = squin.qubit.new(3)
12+
13+
m = squin.qubit.measure(q)
14+
15+
qid = squin.qubit.get_qubit_id(q[0])
16+
mid = squin.qubit.get_measurement_id(m[1])
17+
return mid + qid
18+
19+
main.print()
20+
assert main.return_type.is_subseteq(types.Int)
21+
22+
@squin.kernel
23+
def main2():
24+
q = squin.qubit.new(2)
25+
26+
qid = squin.qubit.get_qubit_id(q[0])
27+
m1 = squin.qubit.measure(q[qid])
28+
29+
squin.gate.x(q[qid])
30+
m2 = squin.qubit.measure(q[qid])
31+
32+
m1_id = squin.qubit.get_measurement_id(m1)
33+
m2_id = squin.qubit.get_measurement_id(m2)
34+
35+
if m1_id != 0:
36+
# do something that errors
37+
squin.gate.x(q[4])
38+
39+
if m2_id != 1:
40+
squin.gate.x(q[4])
41+
42+
return squin.qubit.measure(q)
43+
44+
sim = StackMemorySimulator(min_qubits=2)
45+
result = sim.run(main2)
46+
for i, res in enumerate(result):
47+
assert isinstance(res, Measurement)
48+
assert res.measurement_id == i + 2

0 commit comments

Comments
 (0)