Skip to content

Commit 600a48a

Browse files
committed
initial measurement predicate implementation
1 parent 1642980 commit 600a48a

File tree

15 files changed

+329
-39
lines changed

15 files changed

+329
-39
lines changed

src/bloqade/analysis/measure_id/impls.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
from bloqade import qubit
66

77
from .lattice import (
8+
Predicate,
89
AnyMeasureId,
910
NotMeasureId,
11+
RawMeasureId,
1012
MeasureIdBool,
1113
MeasureIdTuple,
1214
InvalidMeasureId,
@@ -41,10 +43,45 @@ def measure_qubit_list(
4143
measure_id_bools = []
4244
for _ in range(num_qubits.data):
4345
interp.measure_count += 1
44-
measure_id_bools.append(MeasureIdBool(interp.measure_count))
46+
measure_id_bools.append(RawMeasureId(interp.measure_count))
4547

4648
return (MeasureIdTuple(data=tuple(measure_id_bools)),)
4749

50+
@interp.impl(qubit.stmts.IsLost)
51+
@interp.impl(qubit.stmts.IsOne)
52+
@interp.impl(qubit.stmts.IsZero)
53+
def measurement_predicate(
54+
self,
55+
interp: MeasurementIDAnalysis,
56+
frame: interp.Frame,
57+
stmt: qubit.stmts.IsLost | qubit.stmts.IsOne | qubit.stmts.IsZero,
58+
):
59+
original_measure_id_tuple = frame.get(stmt.measurements)
60+
# all members should be RawMeasureId, if it's anything else
61+
# it's Invalid.
62+
if not all(
63+
isinstance(measure_id, RawMeasureId)
64+
for measure_id in original_measure_id_tuple.data
65+
):
66+
return (InvalidMeasureId(),)
67+
68+
# get the proper predicate type
69+
if isinstance(stmt, qubit.stmts.IsLost):
70+
predicate = Predicate.IS_LOST
71+
elif isinstance(stmt, qubit.stmts.IsOne):
72+
predicate = Predicate.IS_ONE
73+
elif isinstance(stmt, qubit.stmts.IsZero):
74+
predicate = Predicate.IS_ZERO
75+
else:
76+
return (InvalidMeasureId(),)
77+
78+
# Create new MeasureIdBools with proper predicate type
79+
predicate_measure_ids = [
80+
MeasureIdBool(measure_id.idx, predicate)
81+
for measure_id in original_measure_id_tuple.data
82+
]
83+
return (MeasureIdTuple(data=tuple(predicate_measure_ids)),)
84+
4885

4986
@ilist.dialect.register(key="measure_id")
5087
class IList(interp.MethodTable):

src/bloqade/analysis/measure_id/lattice.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from enum import Enum
12
from typing import final
23
from dataclasses import dataclass
34

@@ -8,8 +9,15 @@
89
SimpleMeetMixin,
910
)
1011

12+
13+
class Predicate(Enum):
14+
IS_ZERO = 1
15+
IS_ONE = 2
16+
IS_LOST = 3
17+
18+
1119
# Taken directly from Kai-Hsin Wu's implementation
12-
# with minor changes to names and addition of CanMeasureId type
20+
# with minor changes to names
1321

1422

1523
@dataclass
@@ -57,18 +65,25 @@ def is_subseteq(self, other: MeasureId) -> bool:
5765

5866
@final
5967
@dataclass
60-
class MeasureIdBool(MeasureId):
68+
class RawMeasureId(MeasureId):
6169
idx: int
6270

6371
def is_subseteq(self, other: MeasureId) -> bool:
64-
if isinstance(other, MeasureIdBool):
72+
if isinstance(other, RawMeasureId):
6573
return self.idx == other.idx
6674
return False
6775

6876

69-
# Might be nice to have some print override
70-
# here so all the CanMeasureId's/other types are consolidated for
71-
# readability
77+
@final
78+
@dataclass
79+
class MeasureIdBool(MeasureId):
80+
idx: int
81+
predicate: Predicate
82+
83+
def is_subseteq(self, other: MeasureId) -> bool:
84+
if isinstance(other, MeasureIdBool):
85+
return self.predicate == other.predicate and self.idx == other.idx
86+
return False
7287

7388

7489
@final

src/bloqade/qubit/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from ._prelude import kernel as kernel
77
from .stdlib.simple import (
88
reset as reset,
9+
is_one as is_one,
10+
is_lost as is_lost,
11+
is_zero as is_zero,
912
measure as measure,
1013
get_qubit_id as get_qubit_id,
1114
get_measurement_id as get_measurement_id,

src/bloqade/qubit/_interface.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from bloqade.types import Qubit, MeasurementResult
77

8-
from .stmts import New, Reset, Measure, QubitId, MeasurementId
8+
from .stmts import New, IsOne, Reset, IsLost, IsZero, Measure, QubitId, MeasurementId
99

1010

1111
@wraps(New)
@@ -47,3 +47,19 @@ def get_measurement_id(
4747

4848
@wraps(Reset)
4949
def reset(qubits: ilist.IList[Qubit, Any]) -> None: ...
50+
51+
52+
@wraps(IsZero)
53+
def is_zero(
54+
measurements: ilist.IList[MeasurementResult, N],
55+
) -> ilist.IList[bool, N]: ...
56+
57+
58+
@wraps(IsOne)
59+
def is_one(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]: ...
60+
61+
62+
@wraps(IsLost)
63+
def is_lost(
64+
measurements: ilist.IList[MeasurementResult, N],
65+
) -> ilist.IList[bool, N]: ...

src/bloqade/qubit/stdlib/broadcast.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,39 @@ def get_measurement_id(
6060
measurement_ids (IList[int, N]): The list of global, unique IDs of the measurements.
6161
"""
6262
return _qubit.get_measurement_id(measurements)
63+
64+
65+
@kernel
66+
def is_zero(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
67+
"""Check if each MeasurementResult in the list is equivalent to measuring the zero state.
68+
69+
Args:
70+
measurements (IList[MeasurementResult, N]): The list of measurement results to check.
71+
Returns:
72+
IList[bool, N]: A list of booleans indicating whether each MeasurementResult is equivalent to the zero state.
73+
"""
74+
return _qubit.is_zero(measurements)
75+
76+
77+
@kernel
78+
def is_one(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
79+
"""Check if each MeasurementResult in the list is equivalent to measuring the one state.
80+
81+
Args:
82+
measurements (IList[MeasurementResult, N]): The list of measurement results to check.
83+
Returns:
84+
IList[bool, N]: A list of booleans indicating whether each MeasurementResult is equivalent to the one state.
85+
"""
86+
return _qubit.is_one(measurements)
87+
88+
89+
@kernel
90+
def is_lost(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
91+
"""Check if each MeasurementResult in the list indicates atom loss.
92+
93+
Args:
94+
measurements (IList[MeasurementResult, N]): The list of measurement results to check.
95+
Returns:
96+
IList[bool, N]: A list of booleans indicating whether each MeasurementResult indicates atom loss.
97+
"""
98+
return _qubit.is_lost(measurements)

src/bloqade/qubit/stdlib/simple.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,44 @@ def get_measurement_id(measurement: MeasurementResult) -> int:
5757
"""
5858
ids = broadcast.get_measurement_id(ilist.IList([measurement]))
5959
return ids[0]
60+
61+
62+
@kernel
63+
def is_zero(measurement: MeasurementResult) -> bool:
64+
"""Check if the measurement result is equivalent to measuring the zero state.
65+
66+
Args:
67+
measurement (MeasurementResult): The measurement result to check.
68+
Returns:
69+
bool: True if the measurement result is equivalent to measuring the zero state, False otherwise.
70+
71+
"""
72+
results = broadcast.is_zero(ilist.IList([measurement]))
73+
return results[0]
74+
75+
76+
@kernel
77+
def is_one(measurement: MeasurementResult) -> bool:
78+
"""Check if the measurement result is equivalent to measuring the one state.
79+
80+
Args:
81+
measurement (MeasurementResult): The measurement result to check.
82+
Returns:
83+
bool: True if the measurement result is equivalent to measuring the one state, False otherwise.
84+
85+
"""
86+
results = broadcast.is_one(ilist.IList([measurement]))
87+
return results[0]
88+
89+
90+
@kernel
91+
def is_lost(measurement: MeasurementResult) -> bool:
92+
"""Check if the measurement result indicates atom loss.
93+
94+
Args:
95+
measurement (MeasurementResult): The measurement result to check.
96+
Returns:
97+
bool: True if the measurement result indicates atom loss, False otherwise.
98+
"""
99+
results = broadcast.is_lost(ilist.IList([measurement]))
100+
return results[0]

src/bloqade/qubit/stmts.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,30 @@ class Reset(ir.Statement):
4545
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
4646

4747

48+
@statement
49+
class MeasurementPredicate(ir.Statement):
50+
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
51+
measurements: ir.SSAValue = info.argument(
52+
ilist.IListType[MeasurementResultType, Len]
53+
)
54+
result: ir.ResultValue = info.result(ilist.IListType[types.Bool, Len])
55+
56+
57+
@statement(dialect=dialect)
58+
class IsZero(MeasurementPredicate):
59+
pass
60+
61+
62+
@statement(dialect=dialect)
63+
class IsOne(MeasurementPredicate):
64+
pass
65+
66+
67+
@statement(dialect=dialect)
68+
class IsLost(MeasurementPredicate):
69+
pass
70+
71+
4872
# TODO: investigate why this is needed to get type inference to be correct.
4973
@dialect.register(key="typeinfer")
5074
class __TypeInfer(interp.MethodTable):

src/bloqade/squin/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
from .. import qubit as qubit
77
from ..qubit import (
88
reset as reset,
9+
is_one as is_one,
910
qalloc as qalloc,
11+
is_lost as is_lost,
12+
is_zero as is_zero,
1013
measure as measure,
1114
get_qubit_id as get_qubit_id,
1215
get_measurement_id as get_measurement_id,

src/bloqade/squin/stdlib/broadcast/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,10 @@
3131
two_qubit_pauli_channel as two_qubit_pauli_channel,
3232
single_qubit_pauli_channel as single_qubit_pauli_channel,
3333
)
34-
from ._qubit import reset as reset, measure as measure
34+
from ._qubit import (
35+
reset as reset,
36+
is_one as is_one,
37+
is_lost as is_lost,
38+
is_zero as is_zero,
39+
measure as measure,
40+
)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
from bloqade.qubit.stdlib.broadcast import (
22
reset as reset,
3+
is_one as is_one,
4+
is_lost as is_lost,
5+
is_zero as is_zero,
36
measure as measure,
47
)

0 commit comments

Comments
 (0)