Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion src/bloqade/analysis/measure_id/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from bloqade import qubit, annotate

from .lattice import (
Predicate,
AnyMeasureId,
NotMeasureId,
RawMeasureId,
MeasureIdBool,
MeasureIdTuple,
InvalidMeasureId,
Expand Down Expand Up @@ -41,10 +43,45 @@ def measure_qubit_list(
measure_id_bools = []
for _ in range(num_qubits.data):
interp.measure_count += 1
measure_id_bools.append(MeasureIdBool(interp.measure_count))
measure_id_bools.append(RawMeasureId(interp.measure_count))

return (MeasureIdTuple(data=tuple(measure_id_bools)),)

@interp.impl(qubit.stmts.IsLost)
@interp.impl(qubit.stmts.IsOne)
@interp.impl(qubit.stmts.IsZero)
def measurement_predicate(
self,
interp: MeasurementIDAnalysis,
frame: interp.Frame,
stmt: qubit.stmts.IsLost | qubit.stmts.IsOne | qubit.stmts.IsZero,
):
original_measure_id_tuple = frame.get(stmt.measurements)
# all members should be RawMeasureId, if it's anything else
# it's Invalid.
if not all(
isinstance(measure_id, RawMeasureId)
for measure_id in original_measure_id_tuple.data
):
return (InvalidMeasureId(),)

# get the proper predicate type
if isinstance(stmt, qubit.stmts.IsLost):
predicate = Predicate.IS_LOST
elif isinstance(stmt, qubit.stmts.IsOne):
predicate = Predicate.IS_ONE
elif isinstance(stmt, qubit.stmts.IsZero):
predicate = Predicate.IS_ZERO
else:
return (InvalidMeasureId(),)

# Create new MeasureIdBools with proper predicate type
predicate_measure_ids = [
MeasureIdBool(measure_id.idx, predicate)
for measure_id in original_measure_id_tuple.data
]
return (MeasureIdTuple(data=tuple(predicate_measure_ids)),)


@annotate.dialect.register(key="measure_id")
class Annotate(interp.MethodTable):
Expand Down
27 changes: 21 additions & 6 deletions src/bloqade/analysis/measure_id/lattice.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
from typing import final
from dataclasses import dataclass

Expand All @@ -8,8 +9,15 @@
SimpleMeetMixin,
)


class Predicate(Enum):
IS_ZERO = 1
IS_ONE = 2
IS_LOST = 3


# Taken directly from Kai-Hsin Wu's implementation
# with minor changes to names and addition of CanMeasureId type
# with minor changes to names


@dataclass
Expand Down Expand Up @@ -57,18 +65,25 @@ def is_subseteq(self, other: MeasureId) -> bool:

@final
@dataclass
class MeasureIdBool(MeasureId):
class RawMeasureId(MeasureId):
idx: int

def is_subseteq(self, other: MeasureId) -> bool:
if isinstance(other, MeasureIdBool):
if isinstance(other, RawMeasureId):
return self.idx == other.idx
return False


# Might be nice to have some print override
# here so all the CanMeasureId's/other types are consolidated for
# readability
@final
@dataclass
class MeasureIdBool(MeasureId):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is MeasureIdBool used anywhere else in the pipeline? If so, this might be breaking.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, I did a ripgrep on our workspace and the only things that depend on it are the original physical dialect (which won't depend on it anymore once my simplification PR gets accepted) and bloqade-circuit (which is why this PR exists haha)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, is this PR a backport candidate anyway? I guess so, right? Please add the label :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this should be backported ):

Now that I think about this would definitely be something of a breaking change because on the SquinToStim side I've now gone ahead and enforced the following syntax to be necessary:

qs = squin.qalloc(5)
ms = squin.broadcast.measure(qs)
one_meas_result = squin.is_one(ms[0])

if one_meas_result:
	squin.x(qs[1])
	
# This can become:
# CX rec[-1] 1

Whereas historically you didn't need is_one at all and you could just chuck in your measurement result as the condition for the scf.IfElse.

idx: int
predicate: Predicate

def is_subseteq(self, other: MeasureId) -> bool:
if isinstance(other, MeasureIdBool):
return self.predicate == other.predicate and self.idx == other.idx
return False


@final
Expand Down
3 changes: 3 additions & 0 deletions src/bloqade/qubit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from ._prelude import kernel as kernel
from .stdlib.simple import (
reset as reset,
is_one as is_one,
is_lost as is_lost,
is_zero as is_zero,
measure as measure,
get_qubit_id as get_qubit_id,
get_measurement_id as get_measurement_id,
Expand Down
50 changes: 49 additions & 1 deletion src/bloqade/qubit/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from bloqade.types import Qubit, MeasurementResult

from .stmts import New, Reset, Measure, QubitId, MeasurementId
from .stmts import New, IsOne, Reset, IsLost, IsZero, Measure, QubitId, MeasurementId


@wraps(New)
Expand Down Expand Up @@ -47,3 +47,51 @@ def get_measurement_id(

@wraps(Reset)
def reset(qubits: ilist.IList[Qubit, Any]) -> None: ...


@wraps(IsZero)
def is_zero(
measurements: ilist.IList[MeasurementResult, N],
) -> ilist.IList[bool, N]:
"""
Check if each measurement result in a list corresponds to a measured value of 0.

Args:
measurements (IList[MeasurementResult, N]): The list of measurements to check.

Returns:
IList[bool, N]: A list of booleans indicating whether each measurement result is 0.
"""

...


@wraps(IsOne)
def is_one(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
"""
Check if each measurement result in a list corresponds to a measured value of 1.

Args:
measurements (IList[MeasurementResult, N]): The list of measurements to check.

Returns:
IList[bool, N]: A list of booleans indicating whether each measurement result is 1.
"""
...


@wraps(IsLost)
def is_lost(
measurements: ilist.IList[MeasurementResult, N],
) -> ilist.IList[bool, N]:
"""
Check if each measurement result in a list corresponds to a lost atom.

Args:
measurements (IList[MeasurementResult, N]): The list of measurements to check.

Returns:
IList[bool, N]: A list of booleans indicating whether each measurement indicates the atom was lost.

"""
...
36 changes: 36 additions & 0 deletions src/bloqade/qubit/stdlib/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,39 @@ def get_measurement_id(
measurement_ids (IList[int, N]): The list of global, unique IDs of the measurements.
"""
return _qubit.get_measurement_id(measurements)


@kernel
def is_zero(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
"""Check if each MeasurementResult in the list is equivalent to measuring the zero state.

Args:
measurements (IList[MeasurementResult, N]): The list of measurement results to check.
Returns:
IList[bool, N]: A list of booleans indicating whether each MeasurementResult is equivalent to the zero state.
"""
return _qubit.is_zero(measurements)


@kernel
def is_one(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
"""Check if each MeasurementResult in the list is equivalent to measuring the one state.

Args:
measurements (IList[MeasurementResult, N]): The list of measurement results to check.
Returns:
IList[bool, N]: A list of booleans indicating whether each MeasurementResult is equivalent to the one state.
"""
return _qubit.is_one(measurements)


@kernel
def is_lost(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
"""Check if each MeasurementResult in the list indicates atom loss.

Args:
measurements (IList[MeasurementResult, N]): The list of measurement results to check.
Returns:
IList[bool, N]: A list of booleans indicating whether each MeasurementResult indicates atom loss.
"""
return _qubit.is_lost(measurements)
41 changes: 41 additions & 0 deletions src/bloqade/qubit/stdlib/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,44 @@ def get_measurement_id(measurement: MeasurementResult) -> int:
"""
ids = broadcast.get_measurement_id(ilist.IList([measurement]))
return ids[0]


@kernel
def is_zero(measurement: MeasurementResult) -> bool:
"""Check if the measurement result is equivalent to measuring the zero state.

Args:
measurement (MeasurementResult): The measurement result to check.
Returns:
bool: True if the measurement result is equivalent to measuring the zero state, False otherwise.

"""
results = broadcast.is_zero(ilist.IList([measurement]))
return results[0]


@kernel
def is_one(measurement: MeasurementResult) -> bool:
"""Check if the measurement result is equivalent to measuring the one state.

Args:
measurement (MeasurementResult): The measurement result to check.
Returns:
bool: True if the measurement result is equivalent to measuring the one state, False otherwise.

"""
results = broadcast.is_one(ilist.IList([measurement]))
return results[0]


@kernel
def is_lost(measurement: MeasurementResult) -> bool:
"""Check if the measurement result indicates atom loss.

Args:
measurement (MeasurementResult): The measurement result to check.
Returns:
bool: True if the measurement result indicates atom loss, False otherwise.
"""
results = broadcast.is_lost(ilist.IList([measurement]))
return results[0]
24 changes: 24 additions & 0 deletions src/bloqade/qubit/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,30 @@ class Reset(ir.Statement):
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])


@statement
class MeasurementPredicate(ir.Statement):
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
measurements: ir.SSAValue = info.argument(
ilist.IListType[MeasurementResultType, Len]
)
result: ir.ResultValue = info.result(ilist.IListType[types.Bool, Len])


@statement(dialect=dialect)
class IsZero(MeasurementPredicate):
pass


@statement(dialect=dialect)
class IsOne(MeasurementPredicate):
pass


@statement(dialect=dialect)
class IsLost(MeasurementPredicate):
pass


# TODO: investigate why this is needed to get type inference to be correct.
@dialect.register(key="typeinfer")
class __TypeInfer(interp.MethodTable):
Expand Down
3 changes: 3 additions & 0 deletions src/bloqade/squin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from .. import qubit as qubit, annotate as annotate
from ..qubit import (
reset as reset,
is_one as is_one,
qalloc as qalloc,
is_lost as is_lost,
is_zero as is_zero,
measure as measure,
get_qubit_id as get_qubit_id,
get_measurement_id as get_measurement_id,
Expand Down
8 changes: 7 additions & 1 deletion src/bloqade/squin/stdlib/broadcast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,10 @@
two_qubit_pauli_channel as two_qubit_pauli_channel,
single_qubit_pauli_channel as single_qubit_pauli_channel,
)
from ._qubit import reset as reset, measure as measure
from ._qubit import (
reset as reset,
is_one as is_one,
is_lost as is_lost,
is_zero as is_zero,
measure as measure,
)
3 changes: 3 additions & 0 deletions src/bloqade/squin/stdlib/broadcast/_qubit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from bloqade.qubit.stdlib.broadcast import (
reset as reset,
is_one as is_one,
is_lost as is_lost,
is_zero as is_zero,
measure as measure,
)
8 changes: 4 additions & 4 deletions src/bloqade/stim/rewrite/get_record_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from kirin.dialects import py

from bloqade.stim.dialects import auxiliary
from bloqade.analysis.measure_id.lattice import MeasureIdBool, MeasureIdTuple
from bloqade.analysis.measure_id.lattice import RawMeasureId, MeasureIdTuple


def insert_get_records(
Expand All @@ -12,9 +12,9 @@ def insert_get_records(
Insert GetRecord statements before the given node
"""
get_record_ssas = []
for measure_id_bool in measure_id_tuple.data:
assert isinstance(measure_id_bool, MeasureIdBool)
target_rec_idx = (measure_id_bool.idx - 1) - meas_count_at_stmt
for measure_id in measure_id_tuple.data:
assert isinstance(measure_id, RawMeasureId)
target_rec_idx = (measure_id.idx - 1) - meas_count_at_stmt
idx_stmt = py.constant.Constant(target_rec_idx)
idx_stmt.insert_before(node)
get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
Expand Down
25 changes: 13 additions & 12 deletions src/bloqade/stim/rewrite/ifs_to_stim.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
from bloqade.stim.dialects.gate import CX as stim_CX, CY as stim_CY, CZ as stim_CZ
from bloqade.analysis.measure_id import MeasureIDFrame
from bloqade.stim.dialects.auxiliary import GetRecord
from bloqade.analysis.measure_id.lattice import (
MeasureIdBool,
)
from bloqade.analysis.measure_id.lattice import Predicate, MeasureIdBool


@dataclass
Expand Down Expand Up @@ -139,8 +137,13 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:

def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:

# Check the condition is a singular MeasurementIdBool
if not isinstance(self.measure_frame.entries[stmt.cond], MeasureIdBool):
condition_type = self.measure_frame.entries.get(stmt.cond)
# Check the condition is a singular MeasurementIdBool and that
# it was generated by querying if the measurement is equivalent to the one state
if not isinstance(condition_type, MeasureIdBool):
return RewriteResult()

if condition_type.predicate != Predicate.IS_ONE:
return RewriteResult()

# Reusing code from SplitIf,
Expand All @@ -158,14 +161,12 @@ def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:
else:
return RewriteResult()

# get necessary measurement ID type from analysis
measure_id_bool = self.measure_frame.entries[stmt.cond]
assert isinstance(measure_id_bool, MeasureIdBool)

# generate get record statement
measure_id_idx_stmt = py.Constant(
(measure_id_bool.idx - 1) - self.measure_frame.num_measures_at_stmt[stmt]
)
num_measures = self.measure_frame.num_measures_at_stmt.get(stmt)
if num_measures is None:
return RewriteResult()

measure_id_idx_stmt = py.Constant((condition_type.idx - 1) - num_measures)
get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841

address_attr = stmts[0].qubits.hints.get("address")
Expand Down
Loading