Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 1 addition & 19 deletions src/bloqade/analysis/measure_id/analysis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import TypeVar
from dataclasses import field, dataclass

from kirin import ir
from kirin.analysis import ForwardExtra, const
from kirin.analysis import ForwardExtra
from kirin.analysis.forward import ForwardFrame

from .lattice import MeasureId, NotMeasureId
Expand Down Expand Up @@ -33,22 +32,5 @@ def eval_fallback(
) -> tuple[MeasureId, ...]:
return tuple(NotMeasureId() for _ in node.results)

# Xiu-zhe (Roger) Luo came up with this in the address analysis,
# reused here for convenience (now modified to be a bit more graceful)
# TODO: Remove this function once upgrade to kirin 0.18 happens,
# method is built-in to interpreter then

T = TypeVar("T")

def get_const_value(
self, input_type: type[T] | tuple[type[T], ...], value: ir.SSAValue
) -> type[T] | None:
if isinstance(hint := value.hints.get("const"), const.Value):
data = hint.data
if isinstance(data, input_type):
return hint.data

return None

def method_self(self, method: ir.Method) -> MeasureId:
return self.lattice.bottom()
46 changes: 35 additions & 11 deletions src/bloqade/analysis/measure_id/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@
from bloqade import qubit, annotate

from .lattice import (
Predicate,
AnyMeasureId,
NotMeasureId,
RawMeasureId,
MeasureIdBool,
MeasureIdTuple,
InvalidMeasureId,
)
from .analysis import MeasureIDFrame, MeasurementIDAnalysis

## Can't do wire right now because of
## unresolved RFC on return type
# from bloqade.squin import wire


@qubit.dialect.register(key="measure_id")
class SquinQubit(interp.MethodTable):
Expand All @@ -30,7 +28,6 @@ def measure_qubit_list(
):

# try to get the length of the list
## "...safely assume the type inference will give you what you need"
qubits_type = stmt.qubits.type
# vars[0] is just the type of the elements in the ilist,
# vars[1] can contain a literal with length information
Expand All @@ -41,10 +38,41 @@ def measure_qubit_list(
measure_id_bools = []
for _ in range(num_qubits.data):
interp.measure_count += 1
measure_id_bools.append(MeasureIdBool(interp.measure_count))
measure_id_bools.append(RawMeasureId(interp.measure_count))

return (MeasureIdTuple(data=tuple(measure_id_bools)),)

@interp.impl(qubit.stmts.IsLost)
@interp.impl(qubit.stmts.IsOne)
@interp.impl(qubit.stmts.IsZero)
def measurement_predicate(
self,
interp: MeasurementIDAnalysis,
frame: interp.Frame,
stmt: qubit.stmts.IsLost | qubit.stmts.IsOne | qubit.stmts.IsZero,
):
original_measure_id_tuple = frame.get(stmt.measurements)
if not all(
isinstance(measure_id, RawMeasureId)
for measure_id in original_measure_id_tuple.data
):
return (InvalidMeasureId(),)

if isinstance(stmt, qubit.stmts.IsLost):
predicate = Predicate.IS_LOST
elif isinstance(stmt, qubit.stmts.IsOne):
predicate = Predicate.IS_ONE
elif isinstance(stmt, qubit.stmts.IsZero):
predicate = Predicate.IS_ZERO
else:
return (InvalidMeasureId(),)

predicate_measure_ids = [
MeasureIdBool(measure_id.idx, predicate)
for measure_id in original_measure_id_tuple.data
]
return (MeasureIdTuple(data=tuple(predicate_measure_ids)),)


@annotate.dialect.register(key="measure_id")
class Annotate(interp.MethodTable):
Expand Down Expand Up @@ -94,14 +122,10 @@ def getitem(
self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.GetItem
):

idx_or_slice = interp.get_const_value((int, slice), stmt.index)
idx_or_slice = interp.maybe_const(stmt.index, (int, slice))
if idx_or_slice is None:
return (InvalidMeasureId(),)

# hint = stmt.index.hints.get("const")
# if hint is None or not isinstance(hint, const.Value):
# return (InvalidMeasureId(),)

obj = frame.get(stmt.obj)
if isinstance(obj, MeasureIdTuple):
if isinstance(idx_or_slice, slice):
Expand Down
33 changes: 27 additions & 6 deletions src/bloqade/analysis/measure_id/lattice.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
from typing import final
from dataclasses import dataclass

Expand All @@ -8,8 +9,21 @@
SimpleMeetMixin,
)


class Predicate(Enum):
IS_ZERO = 1
IS_ONE = 2
IS_LOST = 3

def __str__(self) -> str:
return self.name

def __repr__(self) -> str:
return self.name


# Taken directly from Kai-Hsin Wu's implementation
# with minor changes to names and addition of CanMeasureId type
# with minor changes to names


@dataclass
Expand Down Expand Up @@ -57,18 +71,25 @@ def is_subseteq(self, other: MeasureId) -> bool:

@final
@dataclass
class MeasureIdBool(MeasureId):
class RawMeasureId(MeasureId):
idx: int

def is_subseteq(self, other: MeasureId) -> bool:
if isinstance(other, MeasureIdBool):
if isinstance(other, RawMeasureId):
return self.idx == other.idx
return False


# Might be nice to have some print override
# here so all the CanMeasureId's/other types are consolidated for
# readability
@final
@dataclass
class MeasureIdBool(MeasureId):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is MeasureIdBool used anywhere else in the pipeline? If so, this might be breaking.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, I did a ripgrep on our workspace and the only things that depend on it are the original physical dialect (which won't depend on it anymore once my simplification PR gets accepted) and bloqade-circuit (which is why this PR exists haha)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, is this PR a backport candidate anyway? I guess so, right? Please add the label :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this should be backported ):

Now that I think about this would definitely be something of a breaking change because on the SquinToStim side I've now gone ahead and enforced the following syntax to be necessary:

qs = squin.qalloc(5)
ms = squin.broadcast.measure(qs)
one_meas_result = squin.is_one(ms[0])

if one_meas_result:
	squin.x(qs[1])
	
# This can become:
# CX rec[-1] 1

Whereas historically you didn't need is_one at all and you could just chuck in your measurement result as the condition for the scf.IfElse.

idx: int
predicate: Predicate

def is_subseteq(self, other: MeasureId) -> bool:
if isinstance(other, MeasureIdBool):
return self.predicate == other.predicate and self.idx == other.idx
return False


@final
Expand Down
3 changes: 3 additions & 0 deletions src/bloqade/qubit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from ._prelude import kernel as kernel
from .stdlib.simple import (
reset as reset,
is_one as is_one,
is_lost as is_lost,
is_zero as is_zero,
measure as measure,
get_qubit_id as get_qubit_id,
get_measurement_id as get_measurement_id,
Expand Down
50 changes: 49 additions & 1 deletion src/bloqade/qubit/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from bloqade.types import Qubit, MeasurementResult

from .stmts import New, Reset, Measure, QubitId, MeasurementId
from .stmts import New, IsOne, Reset, IsLost, IsZero, Measure, QubitId, MeasurementId


@wraps(New)
Expand Down Expand Up @@ -47,3 +47,51 @@ def get_measurement_id(

@wraps(Reset)
def reset(qubits: ilist.IList[Qubit, Any]) -> None: ...


@wraps(IsZero)
def is_zero(
measurements: ilist.IList[MeasurementResult, N],
) -> ilist.IList[bool, N]:
"""
Check if each measurement result in a list corresponds to a measured value of 0.

Args:
measurements (IList[MeasurementResult, N]): The list of measurements to check.

Returns:
IList[bool, N]: A list of booleans indicating whether each measurement result is 0.
"""

...


@wraps(IsOne)
def is_one(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
"""
Check if each measurement result in a list corresponds to a measured value of 1.

Args:
measurements (IList[MeasurementResult, N]): The list of measurements to check.

Returns:
IList[bool, N]: A list of booleans indicating whether each measurement result is 1.
"""
...


@wraps(IsLost)
def is_lost(
measurements: ilist.IList[MeasurementResult, N],
) -> ilist.IList[bool, N]:
"""
Check if each measurement result in a list corresponds to a lost atom.

Args:
measurements (IList[MeasurementResult, N]): The list of measurements to check.

Returns:
IList[bool, N]: A list of booleans indicating whether each measurement indicates the atom was lost.

"""
...
36 changes: 36 additions & 0 deletions src/bloqade/qubit/stdlib/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,39 @@ def get_measurement_id(
measurement_ids (IList[int, N]): The list of global, unique IDs of the measurements.
"""
return _qubit.get_measurement_id(measurements)


@kernel
def is_zero(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
"""Check if each MeasurementResult in the list is equivalent to measuring the zero state.

Args:
measurements (IList[MeasurementResult, N]): The list of measurement results to check.
Returns:
IList[bool, N]: A list of booleans indicating whether each MeasurementResult is equivalent to the zero state.
"""
return _qubit.is_zero(measurements)


@kernel
def is_one(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
"""Check if each MeasurementResult in the list is equivalent to measuring the one state.

Args:
measurements (IList[MeasurementResult, N]): The list of measurement results to check.
Returns:
IList[bool, N]: A list of booleans indicating whether each MeasurementResult is equivalent to the one state.
"""
return _qubit.is_one(measurements)


@kernel
def is_lost(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
"""Check if each MeasurementResult in the list indicates atom loss.

Args:
measurements (IList[MeasurementResult, N]): The list of measurement results to check.
Returns:
IList[bool, N]: A list of booleans indicating whether each MeasurementResult indicates atom loss.
"""
return _qubit.is_lost(measurements)
41 changes: 41 additions & 0 deletions src/bloqade/qubit/stdlib/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,44 @@ def get_measurement_id(measurement: MeasurementResult) -> int:
"""
ids = broadcast.get_measurement_id(ilist.IList([measurement]))
return ids[0]


@kernel
def is_zero(measurement: MeasurementResult) -> bool:
"""Check if the measurement result is equivalent to measuring the zero state.

Args:
measurement (MeasurementResult): The measurement result to check.
Returns:
bool: True if the measurement result is equivalent to measuring the zero state, False otherwise.

"""
results = broadcast.is_zero(ilist.IList([measurement]))
return results[0]


@kernel
def is_one(measurement: MeasurementResult) -> bool:
"""Check if the measurement result is equivalent to measuring the one state.

Args:
measurement (MeasurementResult): The measurement result to check.
Returns:
bool: True if the measurement result is equivalent to measuring the one state, False otherwise.

"""
results = broadcast.is_one(ilist.IList([measurement]))
return results[0]


@kernel
def is_lost(measurement: MeasurementResult) -> bool:
"""Check if the measurement result indicates atom loss.

Args:
measurement (MeasurementResult): The measurement result to check.
Returns:
bool: True if the measurement result indicates atom loss, False otherwise.
"""
results = broadcast.is_lost(ilist.IList([measurement]))
return results[0]
24 changes: 24 additions & 0 deletions src/bloqade/qubit/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,30 @@ class Reset(ir.Statement):
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])


@statement
class MeasurementPredicate(ir.Statement):
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
measurements: ir.SSAValue = info.argument(
ilist.IListType[MeasurementResultType, Len]
)
result: ir.ResultValue = info.result(ilist.IListType[types.Bool, Len])


@statement(dialect=dialect)
class IsZero(MeasurementPredicate):
pass


@statement(dialect=dialect)
class IsOne(MeasurementPredicate):
pass


@statement(dialect=dialect)
class IsLost(MeasurementPredicate):
pass


# TODO: investigate why this is needed to get type inference to be correct.
@dialect.register(key="typeinfer")
class __TypeInfer(interp.MethodTable):
Expand Down
3 changes: 3 additions & 0 deletions src/bloqade/squin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from .. import qubit as qubit, annotate as annotate
from ..qubit import (
reset as reset,
is_one as is_one,
qalloc as qalloc,
is_lost as is_lost,
is_zero as is_zero,
measure as measure,
get_qubit_id as get_qubit_id,
get_measurement_id as get_measurement_id,
Expand Down
8 changes: 7 additions & 1 deletion src/bloqade/squin/stdlib/broadcast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,10 @@
two_qubit_pauli_channel as two_qubit_pauli_channel,
single_qubit_pauli_channel as single_qubit_pauli_channel,
)
from ._qubit import reset as reset, measure as measure
from ._qubit import (
reset as reset,
is_one as is_one,
is_lost as is_lost,
is_zero as is_zero,
measure as measure,
)
3 changes: 3 additions & 0 deletions src/bloqade/squin/stdlib/broadcast/_qubit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from bloqade.qubit.stdlib.broadcast import (
reset as reset,
is_one as is_one,
is_lost as is_lost,
is_zero as is_zero,
measure as measure,
)
Loading