Skip to content

Commit 7885984

Browse files
committed
Also make QubitId and MeasurementId adhere to broadcasting semantics
1 parent eb30900 commit 7885984

File tree

5 files changed

+47
-13
lines changed

5 files changed

+47
-13
lines changed

src/bloqade/pyqrack/squin/qubit.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,17 @@ def measure_qubit_list(
4242
def qubit_id(
4343
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.QubitId
4444
):
45-
qbit: PyQrackQubit = frame.get(stmt.qubit)
46-
return (qbit.addr,)
45+
qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
46+
ids = ilist.IList([qbit.addr for qbit in qubits])
47+
return (ids,)
4748

4849
@interp.impl(qubit.MeasurementId)
4950
def measurement_id(
5051
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasurementId
5152
):
52-
measurement: Measurement = frame.get(stmt.measurement)
53-
return (measurement.measurement_id,)
53+
measurements: ilist.IList[Measurement, Any] = frame.get(stmt.measurements)
54+
ids = ilist.IList([measurement.measurement_id for measurement in measurements])
55+
return (ids,)
5456

5557
@interp.impl(qubit.Reset)
5658
def reset(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Reset):

src/bloqade/qubit/_interface.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@ def measure(qubits: ilist.IList[Qubit, N]) -> ilist.IList[MeasurementResult, N]:
3636

3737

3838
@wraps(QubitId)
39-
def get_qubit_id(qubit: Qubit) -> int: ...
39+
def get_qubit_id(qubits: ilist.IList[Qubit, N]) -> ilist.IList[int, N]: ...
4040

4141

4242
@wraps(MeasurementId)
43-
def get_measurement_id(measurement: MeasurementResult) -> int: ...
43+
def get_measurement_id(
44+
measurements: ilist.IList[MeasurementResult, N],
45+
) -> ilist.IList[int, N]: ...
4446

4547

4648
@wraps(Reset)

src/bloqade/qubit/stdlib/broadcast.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,30 @@ def measure(qubits: ilist.IList[Qubit, N]) -> ilist.IList[MeasurementResult, N]:
3333
A MeasurementResult can represent both 0 and 1, but also atoms that are lost.
3434
"""
3535
return _qubit.measure(qubits)
36+
37+
38+
@kernel
39+
def get_qubit_id(qubits: ilist.IList[Qubit, N]) -> ilist.IList[int, N]:
40+
"""Get the global, unique ID of each qubit in the list.
41+
42+
Args:
43+
qubits (IList[Qubit, N]): The list of qubits of which you want the ID.
44+
45+
Returns:
46+
qubit_ids (IList[int, N]): The list of global, unique IDs of the qubits.
47+
"""
48+
return _qubit.get_qubit_id(qubits)
49+
50+
51+
@kernel
52+
def get_measurement_id(
53+
measurements: ilist.IList[MeasurementResult, N],
54+
) -> ilist.IList[int, N]:
55+
"""Get the global, unique ID of each of the measurement results in the list.
56+
57+
Args:
58+
measurements (IList[MeasurementResult, N]): The previously taken measurement of which you want to know the ID.
59+
Returns:
60+
measurement_ids (IList[int, N]): The list of global, unique IDs of the measurements.
61+
"""
62+
return _qubit.get_measurement_id(measurements)

src/bloqade/qubit/stdlib/simple.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from bloqade.types import Qubit, MeasurementResult
44

55
from . import broadcast
6-
from .. import _interface as _qubit
76
from .._prelude import kernel
87

98

@@ -43,7 +42,8 @@ def get_qubit_id(qubit: Qubit) -> int:
4342
Returns:
4443
qubit_id (int): The global, unique ID of the qubit.
4544
"""
46-
return _qubit.get_qubit_id(qubit)
45+
ids = broadcast.get_qubit_id(ilist.IList([qubit]))
46+
return ids[0]
4747

4848

4949
@kernel
@@ -55,4 +55,5 @@ def get_measurement_id(measurement: MeasurementResult) -> int:
5555
Returns:
5656
measurement_id (int): The global, unique ID of the measurement.
5757
"""
58-
return _qubit.get_measurement_id(measurement)
58+
ids = broadcast.get_measurement_id(ilist.IList([measurement]))
59+
return ids[0]

src/bloqade/qubit/stmts.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,17 @@ class Measure(ir.Statement):
2626
@statement(dialect=dialect)
2727
class QubitId(ir.Statement):
2828
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
29-
qubit: ir.SSAValue = info.argument(QubitType)
30-
result: ir.ResultValue = info.result(types.Int)
29+
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len])
30+
result: ir.ResultValue = info.result(ilist.IListType[types.Int, Len])
3131

3232

3333
@statement(dialect=dialect)
3434
class MeasurementId(ir.Statement):
3535
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
36-
measurement: ir.SSAValue = info.argument(MeasurementResultType)
37-
result: ir.ResultValue = info.result(types.Int)
36+
measurements: ir.SSAValue = info.argument(
37+
ilist.IListType[MeasurementResultType, Len]
38+
)
39+
result: ir.ResultValue = info.result(ilist.IListType[types.Int, Len])
3840

3941

4042
@statement(dialect=dialect)

0 commit comments

Comments
 (0)