diff --git a/src/bloqade/pyqrack/qasm2/core.py b/src/bloqade/pyqrack/qasm2/core.py index 45f4baa1..aacad31d 100644 --- a/src/bloqade/pyqrack/qasm2/core.py +++ b/src/bloqade/pyqrack/qasm2/core.py @@ -51,12 +51,30 @@ def creg_get( def measure( self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.Measure ): - qarg: PyQrackQubit = frame.get(stmt.qarg) - carg: CBitRef = frame.get(stmt.carg) - if qarg.is_active(): - carg.set_value(Measurement(qarg.sim_reg.m(qarg.addr))) + qarg: PyQrackQubit | PyQrackReg = frame.get(stmt.qarg) + carg: CBitRef | CRegister = frame.get(stmt.carg) + + if isinstance(qarg, PyQrackQubit) and isinstance(carg, CBitRef): + if qarg.is_active(): + carg.set_value(Measurement(qarg.sim_reg.m(qarg.addr))) + else: + carg.set_value(interp.loss_m_result) + elif isinstance(qarg, PyQrackReg) and isinstance(carg, CRegister): + # TODO: clean up iteration after PyQrackReg is refactored + for i in range(qarg.size): + qubit = qarg[i] + + # TODO: make this consistent with PyQrackReg __getitem__ ? + cbit = CBitRef(carg, i) + + if qubit.is_active(): + cbit.set_value(Measurement(qarg.sim_reg.m(qubit.addr))) + else: + cbit.set_value(interp.loss_m_result) else: - carg.set_value(interp.loss_m_result) + raise RuntimeError( + f"Expected measure call on either a single qubit and classical bit, or two registers, but got the types {type(qarg)} and {type(carg)}" + ) return () diff --git a/src/bloqade/qasm2/_wrappers.py b/src/bloqade/qasm2/_wrappers.py index 4fed2a8b..79e30cba 100644 --- a/src/bloqade/qasm2/_wrappers.py +++ b/src/bloqade/qasm2/_wrappers.py @@ -1,3 +1,5 @@ +from typing import overload + from kirin.lowering import wraps from .types import Bit, CReg, QReg, Qubit @@ -58,7 +60,7 @@ def reset(qarg: Qubit) -> None: ... -@wraps(core.Measure) +@overload def measure(qarg: Qubit, cbit: Bit) -> None: """ Measure the qubit `qarg` and store the result in the classical bit `cbit`. @@ -70,6 +72,22 @@ def measure(qarg: Qubit, cbit: Bit) -> None: ... +@overload +def measure(qarg: QReg, carg: CReg) -> None: + """ + Measure each qubit in the quantum register `qarg` and store the result in the classical register `carg`. + + Args: + qarg: The quantum register to measure. + carg: The classical bit to store the result in. + """ + ... + + +@wraps(core.Measure) +def measure(qarg, carg) -> None: ... + + @wraps(uop.CX) def cx(ctrl: Qubit, qarg: Qubit) -> None: """ diff --git a/src/bloqade/qasm2/dialects/core/stmts.py b/src/bloqade/qasm2/dialects/core/stmts.py index 7231bada..75fec223 100644 --- a/src/bloqade/qasm2/dialects/core/stmts.py +++ b/src/bloqade/qasm2/dialects/core/stmts.py @@ -46,10 +46,21 @@ class Measure(ir.Statement): name = "measure" traits = frozenset({lowering.FromPythonCall()}) - qarg: ir.SSAValue = info.argument(QubitType) - """qarg (Qubit): The qubit to measure.""" - carg: ir.SSAValue = info.argument(BitType) - """carg (Bit): The bit to store the result in.""" + qarg: ir.SSAValue = info.argument(QubitType | QRegType) + """qarg (Qubit | QReg): The qubit or quantum register to measure.""" + carg: ir.SSAValue = info.argument(BitType | CRegType) + """carg (Bit | CReg): The bit or register to store the result in.""" + + def check_type(self) -> None: + qarg_is_qubit = self.qarg.type.is_subseteq(QubitType) + carg_is_bit = self.carg.type.is_subseteq(BitType) + if (qarg_is_qubit and not carg_is_bit) or (not qarg_is_qubit and carg_is_bit): + raise ir.TypeCheckError( + self, + "Can't perform measurement with single (qu)bit and an entire register!", + help="Instead of `measure(qreg[i], creg)` or `measure(qreg, creg[i])`" + "use `measure(qreg[i], creg[j])` or `measure(qreg, creg)`, respectively.", + ) @statement(dialect=dialect) diff --git a/test/pyqrack/test_target.py b/test/pyqrack/test_target.py index 58d3413e..56156dad 100644 --- a/test/pyqrack/test_target.py +++ b/test/pyqrack/test_target.py @@ -1,5 +1,8 @@ import math +import pytest +from kirin import ir + from bloqade import qasm2 from bloqade.pyqrack import PyQrack, reg @@ -111,4 +114,37 @@ def multiple_registers(): assert True -test_target_glob() +def test_measurement(): + + @qasm2.main + def measure_register(): + q = qasm2.qreg(2) + c = qasm2.creg(2) + qasm2.x(q[0]) + qasm2.cx(q[0], q[1]) + qasm2.measure(q, c) + return c + + @qasm2.main + def measure_single_qubits(): + q = qasm2.qreg(2) + c = qasm2.creg(2) + qasm2.x(q[0]) + qasm2.cx(q[0], q[1]) + qasm2.measure(q[0], c[0]) + qasm2.measure(q[1], c[1]) + return c + + target = PyQrack(2) + result_single = target.run(measure_single_qubits) + result_reg = target.run(measure_register) + + assert result_single == result_reg == [reg.Measurement.One, reg.Measurement.One] + + with pytest.raises(ir.ValidationError): + + @qasm2.main + def measurement_that_errors(): + q = qasm2.qreg(1) + c = qasm2.creg(1) + qasm2.measure(q[0], c)