Skip to content

Commit 2e9518c

Browse files
authored
Support measuring registers in pyqrack (#174)
* Support measuring registers in pyqrack * Make tests deterministic * Properly overload wrapper * Fix docstrings of overloaded wrappers * Add type check to Measure stmt and test * Update test to kirin v0.16.8 * Fix Measure stmt docstring * Fix CI by checking for ValidationError
1 parent 1cd01ec commit 2e9518c

File tree

4 files changed

+94
-11
lines changed

4 files changed

+94
-11
lines changed

src/bloqade/pyqrack/qasm2/core.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,30 @@ def creg_get(
5151
def measure(
5252
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.Measure
5353
):
54-
qarg: PyQrackQubit = frame.get(stmt.qarg)
55-
carg: CBitRef = frame.get(stmt.carg)
56-
if qarg.is_active():
57-
carg.set_value(Measurement(qarg.sim_reg.m(qarg.addr)))
54+
qarg: PyQrackQubit | PyQrackReg = frame.get(stmt.qarg)
55+
carg: CBitRef | CRegister = frame.get(stmt.carg)
56+
57+
if isinstance(qarg, PyQrackQubit) and isinstance(carg, CBitRef):
58+
if qarg.is_active():
59+
carg.set_value(Measurement(qarg.sim_reg.m(qarg.addr)))
60+
else:
61+
carg.set_value(interp.loss_m_result)
62+
elif isinstance(qarg, PyQrackReg) and isinstance(carg, CRegister):
63+
# TODO: clean up iteration after PyQrackReg is refactored
64+
for i in range(qarg.size):
65+
qubit = qarg[i]
66+
67+
# TODO: make this consistent with PyQrackReg __getitem__ ?
68+
cbit = CBitRef(carg, i)
69+
70+
if qubit.is_active():
71+
cbit.set_value(Measurement(qarg.sim_reg.m(qubit.addr)))
72+
else:
73+
cbit.set_value(interp.loss_m_result)
5874
else:
59-
carg.set_value(interp.loss_m_result)
75+
raise RuntimeError(
76+
f"Expected measure call on either a single qubit and classical bit, or two registers, but got the types {type(qarg)} and {type(carg)}"
77+
)
6078

6179
return ()
6280

src/bloqade/qasm2/_wrappers.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import overload
2+
13
from kirin.lowering import wraps
24

35
from .types import Bit, CReg, QReg, Qubit
@@ -58,7 +60,7 @@ def reset(qarg: Qubit) -> None:
5860
...
5961

6062

61-
@wraps(core.Measure)
63+
@overload
6264
def measure(qarg: Qubit, cbit: Bit) -> None:
6365
"""
6466
Measure the qubit `qarg` and store the result in the classical bit `cbit`.
@@ -70,6 +72,22 @@ def measure(qarg: Qubit, cbit: Bit) -> None:
7072
...
7173

7274

75+
@overload
76+
def measure(qarg: QReg, carg: CReg) -> None:
77+
"""
78+
Measure each qubit in the quantum register `qarg` and store the result in the classical register `carg`.
79+
80+
Args:
81+
qarg: The quantum register to measure.
82+
carg: The classical bit to store the result in.
83+
"""
84+
...
85+
86+
87+
@wraps(core.Measure)
88+
def measure(qarg, carg) -> None: ...
89+
90+
7391
@wraps(uop.CX)
7492
def cx(ctrl: Qubit, qarg: Qubit) -> None:
7593
"""

src/bloqade/qasm2/dialects/core/stmts.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,21 @@ class Measure(ir.Statement):
4646

4747
name = "measure"
4848
traits = frozenset({lowering.FromPythonCall()})
49-
qarg: ir.SSAValue = info.argument(QubitType)
50-
"""qarg (Qubit): The qubit to measure."""
51-
carg: ir.SSAValue = info.argument(BitType)
52-
"""carg (Bit): The bit to store the result in."""
49+
qarg: ir.SSAValue = info.argument(QubitType | QRegType)
50+
"""qarg (Qubit | QReg): The qubit or quantum register to measure."""
51+
carg: ir.SSAValue = info.argument(BitType | CRegType)
52+
"""carg (Bit | CReg): The bit or register to store the result in."""
53+
54+
def check_type(self) -> None:
55+
qarg_is_qubit = self.qarg.type.is_subseteq(QubitType)
56+
carg_is_bit = self.carg.type.is_subseteq(BitType)
57+
if (qarg_is_qubit and not carg_is_bit) or (not qarg_is_qubit and carg_is_bit):
58+
raise ir.TypeCheckError(
59+
self,
60+
"Can't perform measurement with single (qu)bit and an entire register!",
61+
help="Instead of `measure(qreg[i], creg)` or `measure(qreg, creg[i])`"
62+
"use `measure(qreg[i], creg[j])` or `measure(qreg, creg)`, respectively.",
63+
)
5364

5465

5566
@statement(dialect=dialect)

test/pyqrack/test_target.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import math
22

3+
import pytest
4+
from kirin import ir
5+
36
from bloqade import qasm2
47
from bloqade.pyqrack import PyQrack, reg
58

@@ -111,4 +114,37 @@ def multiple_registers():
111114
assert True
112115

113116

114-
test_target_glob()
117+
def test_measurement():
118+
119+
@qasm2.main
120+
def measure_register():
121+
q = qasm2.qreg(2)
122+
c = qasm2.creg(2)
123+
qasm2.x(q[0])
124+
qasm2.cx(q[0], q[1])
125+
qasm2.measure(q, c)
126+
return c
127+
128+
@qasm2.main
129+
def measure_single_qubits():
130+
q = qasm2.qreg(2)
131+
c = qasm2.creg(2)
132+
qasm2.x(q[0])
133+
qasm2.cx(q[0], q[1])
134+
qasm2.measure(q[0], c[0])
135+
qasm2.measure(q[1], c[1])
136+
return c
137+
138+
target = PyQrack(2)
139+
result_single = target.run(measure_single_qubits)
140+
result_reg = target.run(measure_register)
141+
142+
assert result_single == result_reg == [reg.Measurement.One, reg.Measurement.One]
143+
144+
with pytest.raises(ir.ValidationError):
145+
146+
@qasm2.main
147+
def measurement_that_errors():
148+
q = qasm2.qreg(1)
149+
c = qasm2.creg(1)
150+
qasm2.measure(q[0], c)

0 commit comments

Comments
 (0)