Skip to content

Commit db85cf0

Browse files
authored
predicates for MeasurementResult (#599)
1 parent 6351b5f commit db85cf0

File tree

18 files changed

+440
-79
lines changed

18 files changed

+440
-79
lines changed

src/bloqade/analysis/measure_id/analysis.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
from typing import TypeVar
21
from dataclasses import field, dataclass
32

43
from kirin import ir
5-
from kirin.analysis import ForwardExtra, const
4+
from kirin.analysis import ForwardExtra
65
from kirin.analysis.forward import ForwardFrame
76

87
from .lattice import MeasureId, NotMeasureId
@@ -33,22 +32,5 @@ def eval_fallback(
3332
) -> tuple[MeasureId, ...]:
3433
return tuple(NotMeasureId() for _ in node.results)
3534

36-
# Xiu-zhe (Roger) Luo came up with this in the address analysis,
37-
# reused here for convenience (now modified to be a bit more graceful)
38-
# TODO: Remove this function once upgrade to kirin 0.18 happens,
39-
# method is built-in to interpreter then
40-
41-
T = TypeVar("T")
42-
43-
def get_const_value(
44-
self, input_type: type[T] | tuple[type[T], ...], value: ir.SSAValue
45-
) -> type[T] | None:
46-
if isinstance(hint := value.hints.get("const"), const.Value):
47-
data = hint.data
48-
if isinstance(data, input_type):
49-
return hint.data
50-
51-
return None
52-
5335
def method_self(self, method: ir.Method) -> MeasureId:
5436
return self.lattice.bottom()

src/bloqade/analysis/measure_id/impls.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,16 @@
55
from bloqade import qubit, annotate
66

77
from .lattice import (
8+
Predicate,
89
AnyMeasureId,
910
NotMeasureId,
11+
RawMeasureId,
1012
MeasureIdBool,
1113
MeasureIdTuple,
1214
InvalidMeasureId,
1315
)
1416
from .analysis import MeasureIDFrame, MeasurementIDAnalysis
1517

16-
## Can't do wire right now because of
17-
## unresolved RFC on return type
18-
# from bloqade.squin import wire
19-
2018

2119
@qubit.dialect.register(key="measure_id")
2220
class SquinQubit(interp.MethodTable):
@@ -30,7 +28,6 @@ def measure_qubit_list(
3028
):
3129

3230
# try to get the length of the list
33-
## "...safely assume the type inference will give you what you need"
3431
qubits_type = stmt.qubits.type
3532
# vars[0] is just the type of the elements in the ilist,
3633
# vars[1] can contain a literal with length information
@@ -41,10 +38,41 @@ def measure_qubit_list(
4138
measure_id_bools = []
4239
for _ in range(num_qubits.data):
4340
interp.measure_count += 1
44-
measure_id_bools.append(MeasureIdBool(interp.measure_count))
41+
measure_id_bools.append(RawMeasureId(interp.measure_count))
4542

4643
return (MeasureIdTuple(data=tuple(measure_id_bools)),)
4744

45+
@interp.impl(qubit.stmts.IsLost)
46+
@interp.impl(qubit.stmts.IsOne)
47+
@interp.impl(qubit.stmts.IsZero)
48+
def measurement_predicate(
49+
self,
50+
interp: MeasurementIDAnalysis,
51+
frame: interp.Frame,
52+
stmt: qubit.stmts.IsLost | qubit.stmts.IsOne | qubit.stmts.IsZero,
53+
):
54+
original_measure_id_tuple = frame.get(stmt.measurements)
55+
if not all(
56+
isinstance(measure_id, RawMeasureId)
57+
for measure_id in original_measure_id_tuple.data
58+
):
59+
return (InvalidMeasureId(),)
60+
61+
if isinstance(stmt, qubit.stmts.IsLost):
62+
predicate = Predicate.IS_LOST
63+
elif isinstance(stmt, qubit.stmts.IsOne):
64+
predicate = Predicate.IS_ONE
65+
elif isinstance(stmt, qubit.stmts.IsZero):
66+
predicate = Predicate.IS_ZERO
67+
else:
68+
return (InvalidMeasureId(),)
69+
70+
predicate_measure_ids = [
71+
MeasureIdBool(measure_id.idx, predicate)
72+
for measure_id in original_measure_id_tuple.data
73+
]
74+
return (MeasureIdTuple(data=tuple(predicate_measure_ids)),)
75+
4876

4977
@annotate.dialect.register(key="measure_id")
5078
class Annotate(interp.MethodTable):
@@ -94,14 +122,10 @@ def getitem(
94122
self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.GetItem
95123
):
96124

97-
idx_or_slice = interp.get_const_value((int, slice), stmt.index)
125+
idx_or_slice = interp.maybe_const(stmt.index, (int, slice))
98126
if idx_or_slice is None:
99127
return (InvalidMeasureId(),)
100128

101-
# hint = stmt.index.hints.get("const")
102-
# if hint is None or not isinstance(hint, const.Value):
103-
# return (InvalidMeasureId(),)
104-
105129
obj = frame.get(stmt.obj)
106130
if isinstance(obj, MeasureIdTuple):
107131
if isinstance(idx_or_slice, slice):

src/bloqade/analysis/measure_id/lattice.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from enum import Enum
12
from typing import final
23
from dataclasses import dataclass
34

@@ -8,8 +9,21 @@
89
SimpleMeetMixin,
910
)
1011

12+
13+
class Predicate(Enum):
14+
IS_ZERO = 1
15+
IS_ONE = 2
16+
IS_LOST = 3
17+
18+
def __str__(self) -> str:
19+
return self.name
20+
21+
def __repr__(self) -> str:
22+
return self.name
23+
24+
1125
# Taken directly from Kai-Hsin Wu's implementation
12-
# with minor changes to names and addition of CanMeasureId type
26+
# with minor changes to names
1327

1428

1529
@dataclass
@@ -57,18 +71,25 @@ def is_subseteq(self, other: MeasureId) -> bool:
5771

5872
@final
5973
@dataclass
60-
class MeasureIdBool(MeasureId):
74+
class RawMeasureId(MeasureId):
6175
idx: int
6276

6377
def is_subseteq(self, other: MeasureId) -> bool:
64-
if isinstance(other, MeasureIdBool):
78+
if isinstance(other, RawMeasureId):
6579
return self.idx == other.idx
6680
return False
6781

6882

69-
# Might be nice to have some print override
70-
# here so all the CanMeasureId's/other types are consolidated for
71-
# readability
83+
@final
84+
@dataclass
85+
class MeasureIdBool(MeasureId):
86+
idx: int
87+
predicate: Predicate
88+
89+
def is_subseteq(self, other: MeasureId) -> bool:
90+
if isinstance(other, MeasureIdBool):
91+
return self.predicate == other.predicate and self.idx == other.idx
92+
return False
7293

7394

7495
@final

src/bloqade/qubit/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from ._prelude import kernel as kernel
77
from .stdlib.simple import (
88
reset as reset,
9+
is_one as is_one,
10+
is_lost as is_lost,
11+
is_zero as is_zero,
912
measure as measure,
1013
get_qubit_id as get_qubit_id,
1114
get_measurement_id as get_measurement_id,

src/bloqade/qubit/_interface.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from bloqade.types import Qubit, MeasurementResult
77

8-
from .stmts import New, Reset, Measure, QubitId, MeasurementId
8+
from .stmts import New, IsOne, Reset, IsLost, IsZero, Measure, QubitId, MeasurementId
99

1010

1111
@wraps(New)
@@ -47,3 +47,51 @@ def get_measurement_id(
4747

4848
@wraps(Reset)
4949
def reset(qubits: ilist.IList[Qubit, Any]) -> None: ...
50+
51+
52+
@wraps(IsZero)
53+
def is_zero(
54+
measurements: ilist.IList[MeasurementResult, N],
55+
) -> ilist.IList[bool, N]:
56+
"""
57+
Check if each measurement result in a list corresponds to a measured value of 0.
58+
59+
Args:
60+
measurements (IList[MeasurementResult, N]): The list of measurements to check.
61+
62+
Returns:
63+
IList[bool, N]: A list of booleans indicating whether each measurement result is 0.
64+
"""
65+
66+
...
67+
68+
69+
@wraps(IsOne)
70+
def is_one(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
71+
"""
72+
Check if each measurement result in a list corresponds to a measured value of 1.
73+
74+
Args:
75+
measurements (IList[MeasurementResult, N]): The list of measurements to check.
76+
77+
Returns:
78+
IList[bool, N]: A list of booleans indicating whether each measurement result is 1.
79+
"""
80+
...
81+
82+
83+
@wraps(IsLost)
84+
def is_lost(
85+
measurements: ilist.IList[MeasurementResult, N],
86+
) -> ilist.IList[bool, N]:
87+
"""
88+
Check if each measurement result in a list corresponds to a lost atom.
89+
90+
Args:
91+
measurements (IList[MeasurementResult, N]): The list of measurements to check.
92+
93+
Returns:
94+
IList[bool, N]: A list of booleans indicating whether each measurement indicates the atom was lost.
95+
96+
"""
97+
...

src/bloqade/qubit/stdlib/broadcast.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,39 @@ def get_measurement_id(
6060
measurement_ids (IList[int, N]): The list of global, unique IDs of the measurements.
6161
"""
6262
return _qubit.get_measurement_id(measurements)
63+
64+
65+
@kernel
66+
def is_zero(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
67+
"""Check if each MeasurementResult in the list is equivalent to measuring the zero state.
68+
69+
Args:
70+
measurements (IList[MeasurementResult, N]): The list of measurement results to check.
71+
Returns:
72+
IList[bool, N]: A list of booleans indicating whether each MeasurementResult is equivalent to the zero state.
73+
"""
74+
return _qubit.is_zero(measurements)
75+
76+
77+
@kernel
78+
def is_one(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
79+
"""Check if each MeasurementResult in the list is equivalent to measuring the one state.
80+
81+
Args:
82+
measurements (IList[MeasurementResult, N]): The list of measurement results to check.
83+
Returns:
84+
IList[bool, N]: A list of booleans indicating whether each MeasurementResult is equivalent to the one state.
85+
"""
86+
return _qubit.is_one(measurements)
87+
88+
89+
@kernel
90+
def is_lost(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
91+
"""Check if each MeasurementResult in the list indicates atom loss.
92+
93+
Args:
94+
measurements (IList[MeasurementResult, N]): The list of measurement results to check.
95+
Returns:
96+
IList[bool, N]: A list of booleans indicating whether each MeasurementResult indicates atom loss.
97+
"""
98+
return _qubit.is_lost(measurements)

src/bloqade/qubit/stdlib/simple.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,44 @@ def get_measurement_id(measurement: MeasurementResult) -> int:
5757
"""
5858
ids = broadcast.get_measurement_id(ilist.IList([measurement]))
5959
return ids[0]
60+
61+
62+
@kernel
63+
def is_zero(measurement: MeasurementResult) -> bool:
64+
"""Check if the measurement result is equivalent to measuring the zero state.
65+
66+
Args:
67+
measurement (MeasurementResult): The measurement result to check.
68+
Returns:
69+
bool: True if the measurement result is equivalent to measuring the zero state, False otherwise.
70+
71+
"""
72+
results = broadcast.is_zero(ilist.IList([measurement]))
73+
return results[0]
74+
75+
76+
@kernel
77+
def is_one(measurement: MeasurementResult) -> bool:
78+
"""Check if the measurement result is equivalent to measuring the one state.
79+
80+
Args:
81+
measurement (MeasurementResult): The measurement result to check.
82+
Returns:
83+
bool: True if the measurement result is equivalent to measuring the one state, False otherwise.
84+
85+
"""
86+
results = broadcast.is_one(ilist.IList([measurement]))
87+
return results[0]
88+
89+
90+
@kernel
91+
def is_lost(measurement: MeasurementResult) -> bool:
92+
"""Check if the measurement result indicates atom loss.
93+
94+
Args:
95+
measurement (MeasurementResult): The measurement result to check.
96+
Returns:
97+
bool: True if the measurement result indicates atom loss, False otherwise.
98+
"""
99+
results = broadcast.is_lost(ilist.IList([measurement]))
100+
return results[0]

src/bloqade/qubit/stmts.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,30 @@ class Reset(ir.Statement):
4545
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
4646

4747

48+
@statement
49+
class MeasurementPredicate(ir.Statement):
50+
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
51+
measurements: ir.SSAValue = info.argument(
52+
ilist.IListType[MeasurementResultType, Len]
53+
)
54+
result: ir.ResultValue = info.result(ilist.IListType[types.Bool, Len])
55+
56+
57+
@statement(dialect=dialect)
58+
class IsZero(MeasurementPredicate):
59+
pass
60+
61+
62+
@statement(dialect=dialect)
63+
class IsOne(MeasurementPredicate):
64+
pass
65+
66+
67+
@statement(dialect=dialect)
68+
class IsLost(MeasurementPredicate):
69+
pass
70+
71+
4872
# TODO: investigate why this is needed to get type inference to be correct.
4973
@dialect.register(key="typeinfer")
5074
class __TypeInfer(interp.MethodTable):

src/bloqade/squin/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
from .. import qubit as qubit, annotate as annotate
77
from ..qubit import (
88
reset as reset,
9+
is_one as is_one,
910
qalloc as qalloc,
11+
is_lost as is_lost,
12+
is_zero as is_zero,
1013
measure as measure,
1114
get_qubit_id as get_qubit_id,
1215
get_measurement_id as get_measurement_id,

src/bloqade/squin/stdlib/broadcast/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,10 @@
3131
two_qubit_pauli_channel as two_qubit_pauli_channel,
3232
single_qubit_pauli_channel as single_qubit_pauli_channel,
3333
)
34-
from ._qubit import reset as reset, measure as measure
34+
from ._qubit import (
35+
reset as reset,
36+
is_one as is_one,
37+
is_lost as is_lost,
38+
is_zero as is_zero,
39+
measure as measure,
40+
)

0 commit comments

Comments
 (0)