Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/bloqade/analysis/measure_id/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from . import impls as impls
from .lattice import RecId as RecId
from .analysis import (
MeasureIDFrame as MeasureIDFrame,
MeasurementIDAnalysis as MeasurementIDAnalysis,
Expand Down
13 changes: 9 additions & 4 deletions src/bloqade/analysis/measure_id/analysis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import field, dataclass

from kirin import ir
from kirin.interp import StatementResult
from kirin.analysis import ForwardExtra
from typing_extensions import Self
from kirin.analysis.forward import ForwardFrame
Expand All @@ -17,8 +18,6 @@ class MeasurementIDAnalysis(ForwardExtra[MeasureIDFrame, MeasureId]):

keys = ["measure_id"]
lattice = MeasureId
# for every kind of measurement encountered, increment this
# then use this to generate the negative values for target rec indices
measure_count = 0
detector_count = 0
observable_count = 0
Expand All @@ -34,12 +33,18 @@ def initialize_frame(
) -> MeasureIDFrame:
return MeasureIDFrame(node, has_parent_access=has_parent_access)

# Still default to bottom,
# but let constants return the softer "NoMeasureId" type from impl
def eval_fallback(
self, frame: ForwardFrame[MeasureId], node: ir.Statement
) -> tuple[MeasureId, ...]:
return tuple(NotMeasureId() for _ in node.results)

def frame_eval(
self, frame: MeasureIDFrame, node: ir.Statement
) -> StatementResult[MeasureId]:
method = self.lookup_registry(frame, node)
if method is not None:
return method(self, frame, node)
return self.eval_fallback(frame, node)

def method_self(self, method: ir.Method) -> MeasureId:
return self.lattice.bottom()
41 changes: 41 additions & 0 deletions src/bloqade/analysis/measure_id/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,15 @@

from bloqade import qubit
from bloqade.decoders.dialects import annotate
from bloqade.record_idx_helper import (
GetRecIdxFromPredicate,
GetRecIdxFromMeasurement,
dialect as record_idx_helper_dialect,
)
from bloqade.gemini.logical.dialects import operations

from .lattice import (
RecId,
MeasureId,
Predicate,
DetectorId,
Expand Down Expand Up @@ -316,3 +322,38 @@ def if_else(
return else_results
case _:
return interp_.join_results(then_results, else_results)


@record_idx_helper_dialect.register(key="measure_id")
class RecordIdxHelperAnalysis(interp.MethodTable):

@interp.impl(GetRecIdxFromMeasurement)
def get_rec_idx_from_measurement(
self,
interp_: MeasurementIDAnalysis,
frame: MeasureIDFrame,
stmt: GetRecIdxFromMeasurement,
):
measurement_id = frame.get(stmt.measurement)
if not isinstance(measurement_id, (RawMeasureId, MeasureIdBool)):
return (InvalidMeasureId(),)
computed_idx = (measurement_id.idx - 1) - interp_.measure_count
predicate = (
measurement_id.predicate
if isinstance(measurement_id, MeasureIdBool)
else None
)
return (RecId(idx=computed_idx, predicate=predicate),)

@interp.impl(GetRecIdxFromPredicate)
def get_rec_idx_from_predicate(
self,
interp_: MeasurementIDAnalysis,
frame: MeasureIDFrame,
stmt: GetRecIdxFromPredicate,
):
measurement_id = frame.get(stmt.predicate_result)
if not isinstance(measurement_id, MeasureIdBool):
return (InvalidMeasureId(),)
computed_idx = (measurement_id.idx - 1) - interp_.measure_count
return (RecId(idx=computed_idx, predicate=measurement_id.predicate),)
7 changes: 7 additions & 0 deletions src/bloqade/analysis/measure_id/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,13 @@ class MeasureIdBool(ConcreteMeasureId):
predicate: Predicate


@final
@dataclass
class RecId(ConcreteMeasureId):
idx: int
predicate: Predicate | None


@final
@dataclass
class DetectorId(MeasureId):
Expand Down
5 changes: 5 additions & 0 deletions src/bloqade/record_idx_helper/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .stmts import (
GetRecIdxFromPredicate as GetRecIdxFromPredicate,
GetRecIdxFromMeasurement as GetRecIdxFromMeasurement,
)
from ._dialect import dialect as dialect
3 changes: 3 additions & 0 deletions src/bloqade/record_idx_helper/_dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from kirin import ir

dialect = ir.Dialect("record_idx_helper")
22 changes: 22 additions & 0 deletions src/bloqade/record_idx_helper/stmts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from kirin import ir, types
from kirin.decl import info, statement

from bloqade.types import MeasurementResultType

from ._dialect import dialect


@statement(dialect=dialect)
class GetRecIdxFromMeasurement(ir.Statement):
name = "get_rec_idx_from_measurement"
traits = frozenset({ir.Pure()})
measurement: ir.SSAValue = info.argument(type=MeasurementResultType)
result: ir.ResultValue = info.result(type=types.Int)


@statement(dialect=dialect)
class GetRecIdxFromPredicate(ir.Statement):
name = "get_rec_idx_from_predicate"
traits = frozenset({ir.Pure()})
predicate_result: ir.SSAValue = info.argument(type=types.Bool)
result: ir.ResultValue = info.result(type=types.Int)
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .analysis import StimFromSquinValidation as StimFromSquinValidation
119 changes: 119 additions & 0 deletions src/bloqade/stim/analysis/from_squin_validation/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from typing import Any

from kirin import ir, interp
from kirin.lattice import EmptyLattice
from kirin.analysis import Forward
from kirin.dialects import scf
from kirin.validation import ValidationPass
from kirin.analysis.forward import ForwardFrame

from bloqade.qubit import stmts as qubit_stmts
from bloqade.squin import gate
from bloqade.qubit._dialect import dialect as qubit_dialect

PauliGateType = (gate.stmts.X, gate.stmts.Y, gate.stmts.Z)


class _StimIfElseValidationAnalysis(Forward[EmptyLattice]):
keys = ["stim.validate.from_squin"]

lattice = EmptyLattice

def method_self(self, method: ir.Method) -> EmptyLattice:
return self.lattice.bottom()

def eval_fallback(
self, frame: ForwardFrame[EmptyLattice], node: ir.Statement
) -> tuple[EmptyLattice, ...]:
return tuple(self.lattice.bottom() for _ in range(len(node.results)))


@scf.dialect.register(key="stim.validate.from_squin")
class _ScfMethods(interp.MethodTable):

@interp.impl(scf.IfElse)
def if_else(
self,
interp_: _StimIfElseValidationAnalysis,
frame: ForwardFrame[EmptyLattice],
stmt: scf.IfElse,
):
for child in stmt.walk(include_self=False):
if isinstance(child, scf.IfElse):
interp_.add_validation_error(
stmt,
ir.ValidationError(
stmt,
"Nested IfElse statements are not supported in rewriting to Stim IR.",
),
)
break

if stmt.else_body.blocks and not (
len(stmt.else_body.blocks[0].stmts) == 1
and isinstance(stmt.else_body.blocks[0].last_stmt, scf.Yield)
):
interp_.add_validation_error(
stmt,
ir.ValidationError(
stmt,
"IfElse statements with an else body are not supported in rewriting to Stim IR.",
),
)

for child in stmt.then_body.walk():
if isinstance(child, gate.stmts.Gate) and not isinstance(
child, PauliGateType
):
interp_.add_validation_error(
stmt,
ir.ValidationError(
stmt,
f"Only Pauli gates (X, Y, Z) are allowed inside an scf.IfElse "
f"'then'-body for rewriting to Stim IR. Found: {type(child).__name__}",
),
)


@qubit_dialect.register(key="stim.validate.from_squin")
class _QubitMethods(interp.MethodTable):

@interp.impl(qubit_stmts.IsZero)
def is_zero(
self,
interp_: _StimIfElseValidationAnalysis,
frame: ForwardFrame[EmptyLattice],
stmt: qubit_stmts.IsZero,
):
interp_.add_validation_error(
stmt,
ir.ValidationError(
stmt,
"is_zero predicate is not supported in rewriting to Stim IR. Only the is_one predicate is supported.",
),
)

@interp.impl(qubit_stmts.IsLost)
def is_lost(
self,
interp_: _StimIfElseValidationAnalysis,
frame: ForwardFrame[EmptyLattice],
stmt: qubit_stmts.IsLost,
):
interp_.add_validation_error(
stmt,
ir.ValidationError(
stmt,
"is_lost predicate is not supported in rewriting to Stim IR. Only the is_one predicate is supported.",
),
)


class StimFromSquinValidation(ValidationPass):
def name(self) -> str:
return "Stim from Squin Validation"

def run(self, method: ir.Method) -> tuple[Any, list[ir.ValidationError]]:
analysis = _StimIfElseValidationAnalysis(method.dialects)
frame, _ = analysis.run(method)
return frame, analysis.get_validation_errors()
Loading
Loading