Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ repos:
- id: check-yaml
args: ['--unsafe']
- id: end-of-file-fixer
exclude: .*\.stim$
- id: trailing-whitespace
- repo: https://github.com/pycqa/isort
rev: 6.0.1
Expand Down
16 changes: 15 additions & 1 deletion src/bloqade/analysis/measure_id/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from kirin.analysis import const
from kirin.dialects import py, scf, func, ilist

from bloqade import qubit
from bloqade import qubit, annotate

from .lattice import (
AnyMeasureId,
Expand Down Expand Up @@ -46,6 +46,20 @@ def measure_qubit_list(
return (MeasureIdTuple(data=tuple(measure_id_bools)),)


@annotate.dialect.register(key="measure_id")
class Annotate(interp.MethodTable):
@interp.impl(annotate.stmts.SetObservable)
@interp.impl(annotate.stmts.SetDetector)
def consumes_measurement_results(
self,
interp: MeasurementIDAnalysis,
frame: MeasureIDFrame,
stmt: annotate.stmts.SetObservable | annotate.stmts.SetDetector,
):
frame.num_measures_at_stmt[stmt] = interp.measure_count
return (NotMeasureId(),)


@ilist.dialect.register(key="measure_id")
class IList(interp.MethodTable):
@interp.impl(ilist.New)
Expand Down
6 changes: 6 additions & 0 deletions src/bloqade/annotate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from . import stmts as stmts
from ._dialect import dialect as dialect
from ._interface import (
set_detector as set_detector,
set_observable as set_observable,
)
3 changes: 3 additions & 0 deletions src/bloqade/annotate/_dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from kirin import ir

dialect = ir.Dialect("squin.annotate")
22 changes: 22 additions & 0 deletions src/bloqade/annotate/_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any

from kirin.dialects import ilist
from kirin.lowering import wraps

from bloqade.types import MeasurementResult

from .stmts import SetDetector, SetObservable
from .types import Detector, Observable


@wraps(SetDetector)
def set_detector(
measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult],
coordinates: ilist.IList[float | int, Any] | list[float | int],
) -> Detector: ...


@wraps(SetObservable)
def set_observable(
measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult],
) -> Observable: ...
29 changes: 29 additions & 0 deletions src/bloqade/annotate/stmts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from kirin import ir, types as kirin_types, lowering
from kirin.decl import info, statement
from kirin.dialects import ilist

from bloqade.types import MeasurementResultType
from bloqade.annotate.types import DetectorType, ObservableType

from ._dialect import dialect


@statement
class ConsumesMeasurementResults(ir.Statement):
traits = frozenset({lowering.FromPythonCall()})
measurements: ir.SSAValue = info.argument(
ilist.IListType[MeasurementResultType, kirin_types.Any]
)


@statement(dialect=dialect)
class SetDetector(ConsumesMeasurementResults):
coordinates: ir.SSAValue = info.argument(
type=ilist.IListType[kirin_types.Int | kirin_types.Float, kirin_types.Any]
)
result: ir.ResultValue = info.result(DetectorType)


@statement(dialect=dialect)
class SetObservable(ConsumesMeasurementResults):
result: ir.ResultValue = info.result(ObservableType)
13 changes: 13 additions & 0 deletions src/bloqade/annotate/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from kirin import types


class Detector:
pass


class Observable:
pass


DetectorType = types.PyClass(Detector)
ObservableType = types.PyClass(Observable)
3 changes: 2 additions & 1 deletion src/bloqade/squin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
noise as noise,
analysis as analysis,
)
from .. import qubit as qubit
from .. import qubit as qubit, annotate as annotate
from ..qubit import (
reset as reset,
qalloc as qalloc,
Expand All @@ -12,6 +12,7 @@
get_measurement_id as get_measurement_id,
)
from .groups import kernel as kernel
from ..annotate import set_detector as set_detector, set_observable as set_observable
from .stdlib.simple import (
h as h,
s as s,
Expand Down
4 changes: 2 additions & 2 deletions src/bloqade/squin/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from kirin.dialects import debug, ilist

from . import gate, noise
from .. import qubit
from .. import qubit, annotate


@ir.dialect_group(structural_no_opt.union([qubit, noise, gate, debug]))
@ir.dialect_group(structural_no_opt.union([qubit, noise, gate, debug, annotate]))
def kernel(self):
fold_pass = passes.Fold(self)
typeinfer_pass = passes.TypeInfer(self)
Expand Down
4 changes: 3 additions & 1 deletion src/bloqade/stim/passes/squin_to_stim.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from bloqade.analysis.measure_id import MeasurementIDAnalysis
from bloqade.stim.passes.flatten import Flatten

from ..rewrite.ifs_to_stim import IfToStim
from ..rewrite import IfToStim, SetDetectorToStim, SetObservableToStim


@dataclass
Expand Down Expand Up @@ -64,6 +64,8 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
rewrite_result = (
Chain(
Walk(IfToStim(measure_frame=meas_analysis_frame)),
Walk(SetDetectorToStim(measure_id_frame=meas_analysis_frame)),
Walk(SetObservableToStim(measure_id_frame=meas_analysis_frame)),
Fixpoint(Walk(DeadCodeElimination())),
)
.rewrite(mt.code)
Expand Down
2 changes: 2 additions & 0 deletions src/bloqade/stim/rewrite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
from .py_constant_to_stim import PyConstantToStim as PyConstantToStim
from .set_detector_to_stim import SetDetectorToStim as SetDetectorToStim
from .set_observable_to_stim import SetObservableToStim as SetObservableToStim
24 changes: 24 additions & 0 deletions src/bloqade/stim/rewrite/get_record_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from kirin import ir
from kirin.dialects import py

from bloqade.stim.dialects import auxiliary
from bloqade.analysis.measure_id.lattice import MeasureIdBool, MeasureIdTuple


def insert_get_records(
node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count_at_stmt: int
):
"""
Insert GetRecord statements before the given node
"""
get_record_ssas = []
for measure_id_bool in measure_id_tuple.data:
assert isinstance(measure_id_bool, MeasureIdBool)
target_rec_idx = (measure_id_bool.idx - 1) - meas_count_at_stmt
idx_stmt = py.constant.Constant(target_rec_idx)
idx_stmt.insert_before(node)
get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
get_record_stmt.insert_before(node)
get_record_ssas.append(get_record_stmt.result)

return get_record_ssas
68 changes: 68 additions & 0 deletions src/bloqade/stim/rewrite/set_detector_to_stim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Iterable
from dataclasses import dataclass

from kirin import ir
from kirin.dialects.py import Constant
from kirin.rewrite.abc import RewriteRule, RewriteResult

from bloqade.stim.dialects import auxiliary
from bloqade.annotate.stmts import SetDetector
from bloqade.analysis.measure_id import MeasureIDFrame
from bloqade.stim.dialects.auxiliary import Detector
from bloqade.analysis.measure_id.lattice import MeasureIdTuple

from ..rewrite.get_record_util import insert_get_records


@dataclass
class SetDetectorToStim(RewriteRule):
"""
Rewrite SetDetector to GetRecord and Detector in the stim dialect
"""

measure_id_frame: MeasureIDFrame

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
match node:
case SetDetector():
return self.rewrite_SetDetector(node)
case _:
return RewriteResult()

def rewrite_SetDetector(self, node: SetDetector) -> RewriteResult:

# get coordinates and generate correct consts
coord_ssas = []
if not isinstance(node.coordinates.owner, Constant):
return RewriteResult()

coord_values = node.coordinates.owner.value.unwrap()

if not isinstance(coord_values, Iterable):
return RewriteResult()

if any(not isinstance(value, (int, float)) for value in coord_values):
return RewriteResult()

for coord_value in coord_values:
if isinstance(coord_value, float):
coord_stmt = auxiliary.ConstFloat(value=coord_value)
else: # int
coord_stmt = auxiliary.ConstInt(value=coord_value)
coord_ssas.append(coord_stmt.result)
coord_stmt.insert_before(node)

measure_ids = self.measure_id_frame.entries[node.measurements]
assert isinstance(measure_ids, MeasureIdTuple)

get_record_list = insert_get_records(
node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node]
)

detector_stmt = Detector(
coord=tuple(coord_ssas), targets=tuple(get_record_list)
)

node.replace_by(detector_stmt)

return RewriteResult(has_done_something=True)
52 changes: 52 additions & 0 deletions src/bloqade/stim/rewrite/set_observable_to_stim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from dataclasses import dataclass

from kirin import ir
from kirin.rewrite.abc import RewriteRule, RewriteResult

from bloqade.stim.dialects import auxiliary
from bloqade.annotate.stmts import SetObservable
from bloqade.analysis.measure_id import MeasureIDFrame
from bloqade.stim.dialects.auxiliary import ObservableInclude
from bloqade.analysis.measure_id.lattice import MeasureIdTuple

from ..rewrite.get_record_util import insert_get_records


@dataclass
class SetObservableToStim(RewriteRule):
"""
Rewrite SetObservable to GetRecord and ObservableInclude in the stim dialect
"""

measure_id_frame: MeasureIDFrame

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
match node:
case SetObservable():
return self.rewrite_SetObservable(node)
case _:
return RewriteResult()

def rewrite_SetObservable(self, node: SetObservable) -> RewriteResult:

# set idx to 0 for now, but this
# should be something that a user can set on their own.
# SetObservable needs to accept an int.

idx_stmt = auxiliary.ConstInt(value=0)
idx_stmt.insert_before(node)

measure_ids = self.measure_id_frame.entries[node.measurements]
assert isinstance(measure_ids, MeasureIdTuple)

get_record_list = insert_get_records(
node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node]
)

observable_include_stmt = ObservableInclude(
idx=idx_stmt.result, targets=tuple(get_record_list)
)

node.replace_by(observable_include_stmt)

return RewriteResult(has_done_something=True)
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

X 0 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

RZ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
H 13 14 15 16
CX 13 14 16 1 3 7 0 2 4 10 11 12
CX 13 14 16 2 4 8 3 5 7 10 11 12
CX 13 15 16 0 4 6 1 3 5 9 10 11
CX 13 15 16 1 5 7 4 6 8 9 10 11
H 13 14 15 16
MZ(0.00000000) 13 14 15 16
MZ(0.00000000) 9 10 11 12
RZ 13 14 15 16 9 10 11 12
DETECTOR(0, 0) rec[-4]
DETECTOR(0, 0) rec[-3]
DETECTOR(0, 0) rec[-2]
DETECTOR(0, 0) rec[-1]
H 13 14 15 16
CX 13 14 16 1 3 7 0 2 4 10 11 12
CX 13 14 16 2 4 8 3 5 7 10 11 12
CX 13 15 16 0 4 6 1 3 5 9 10 11
CX 13 15 16 1 5 7 4 6 8 9 10 11
H 13 14 15 16
MZ(0.00000000) 13 14 15 16
MZ(0.00000000) 9 10 11 12
RZ 13 14 15 16 9 10 11 12
DETECTOR(0, 0) rec[-4] rec[-12]
DETECTOR(0, 0) rec[-3] rec[-11]
DETECTOR(0, 0) rec[-2] rec[-10]
DETECTOR(0, 0) rec[-1] rec[-9]
DETECTOR(0, 0) rec[-8] rec[-16]
DETECTOR(0, 0) rec[-7] rec[-15]
DETECTOR(0, 0) rec[-6] rec[-14]
DETECTOR(0, 0) rec[-5] rec[-13]
MZ(0.00000000) 0 1 2 3 4 5 6 7 8
DETECTOR(0, 0) rec[-9] rec[-8] rec[-13]
DETECTOR(0, 0) rec[-8] rec[-7] rec[-5] rec[-4] rec[-12]
DETECTOR(0, 0) rec[-6] rec[-5] rec[-3] rec[-2] rec[-11]
DETECTOR(0, 0) rec[-2] rec[-1] rec[-10]
OBSERVABLE_INCLUDE(0) rec[-9] rec[-8] rec[-7]
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@

MZ(0.00000000) 0 1
DETECTOR(0, 1) rec[-2] rec[-1]
DETECTOR(3, 4) rec[-2] rec[-1]
DETECTOR(0, 5.00000000) rec[-2] rec[-1]
DETECTOR(5.00000000, 3) rec[-2] rec[-1]
DETECTOR(1, 2, 5.00000000) rec[-2] rec[-1]
DETECTOR(1, 2) rec[-2] rec[-1]
DETECTOR(1, 2, 5.00000000) rec[-2] rec[-1]
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

H 1 3 2
MZ(0.00000000) 1 3 2
DETECTOR(0, 0) rec[-3]
DETECTOR(0, 0) rec[-2]
DETECTOR(0, 0) rec[-1]
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@

X 0
Y 1
Z 2
CX 0 1
CX 0 1 2 3
Z 0 1 2 3
MZ(0.00000000) 0 1 2 3
DETECTOR(0.00000000, 0.00000000) rec[-4] rec[-3]
DETECTOR(1.00000000, 0.00000000) rec[-3] rec[-2]
OBSERVABLE_INCLUDE(0) rec[-2]
Loading