Skip to content

Commit e85223f

Browse files
committed
initial annotate dialect implementation
1 parent 1642980 commit e85223f

File tree

13 files changed

+257
-5
lines changed

13 files changed

+257
-5
lines changed

src/bloqade/analysis/measure_id/impls.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.analysis import const
33
from kirin.dialects import py, scf, func, ilist
44

5-
from bloqade import qubit
5+
from bloqade import qubit, annotate
66

77
from .lattice import (
88
AnyMeasureId,
@@ -46,6 +46,20 @@ def measure_qubit_list(
4646
return (MeasureIdTuple(data=tuple(measure_id_bools)),)
4747

4848

49+
@annotate.dialect.register(key="measure_id")
50+
class Annotate(interp.MethodTable):
51+
@interp.impl(annotate.stmts.SetObservable)
52+
@interp.impl(annotate.stmts.SetDetector)
53+
def consumes_measurement_results(
54+
self,
55+
interp: MeasurementIDAnalysis,
56+
frame: MeasureIDFrame,
57+
stmt: annotate.stmts.SetObservable | annotate.stmts.SetDetector,
58+
):
59+
frame.num_measures_at_stmt[stmt] = interp.measure_count
60+
return (NotMeasureId(),)
61+
62+
4963
@ilist.dialect.register(key="measure_id")
5064
class IList(interp.MethodTable):
5165
@interp.impl(ilist.New)

src/bloqade/annotate/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from . import stmts as stmts
2+
from ._dialect import dialect as dialect
3+
from ._interface import (
4+
set_detector as set_detector,
5+
set_observable as set_observable,
6+
)

src/bloqade/annotate/_dialect.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from kirin import ir
2+
3+
dialect = ir.Dialect("squin.annotate")

src/bloqade/annotate/_interface.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from typing import Any
2+
3+
from kirin.dialects import ilist
4+
from kirin.lowering import wraps
5+
6+
from bloqade.types import MeasurementResult
7+
8+
from .stmts import SetDetector, SetObservable
9+
10+
11+
@wraps(SetDetector)
12+
def set_detector(
13+
measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult],
14+
coordinates: tuple[float | int, ...],
15+
) -> None: ...
16+
17+
18+
@wraps(SetObservable)
19+
def set_observable(
20+
measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult],
21+
) -> None: ...

src/bloqade/annotate/stmts.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from kirin import ir, types as kirin_types, lowering
2+
from kirin.decl import info, statement
3+
from kirin.dialects import ilist
4+
5+
from bloqade.types import MeasurementResultType
6+
7+
from ._dialect import dialect
8+
9+
10+
@statement
11+
class ConsumesMeasurementResults(ir.Statement):
12+
traits = frozenset({lowering.FromPythonCall()})
13+
inputs: ir.SSAValue = info.argument(
14+
ilist.IListType[MeasurementResultType, kirin_types.Any]
15+
)
16+
17+
18+
@statement(dialect=dialect)
19+
class SetDetector(ConsumesMeasurementResults):
20+
coordinates: ir.SSAValue = info.argument(
21+
type=kirin_types.Tuple[kirin_types.Int | kirin_types.Float]
22+
)
23+
24+
25+
@statement(dialect=dialect)
26+
class SetObservable(ConsumesMeasurementResults):
27+
pass

src/bloqade/squin/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
noise as noise,
44
analysis as analysis,
55
)
6-
from .. import qubit as qubit
6+
from .. import qubit as qubit, annotate as annotate
77
from ..qubit import (
88
reset as reset,
99
qalloc as qalloc,
@@ -12,6 +12,7 @@
1212
get_measurement_id as get_measurement_id,
1313
)
1414
from .groups import kernel as kernel
15+
from ..annotate import set_detector as set_detector, set_observable as set_observable
1516
from .stdlib.simple import (
1617
h as h,
1718
s as s,

src/bloqade/squin/groups.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from kirin.dialects import debug, ilist
44

55
from . import gate, noise
6-
from .. import qubit
6+
from .. import qubit, annotate
77

88

9-
@ir.dialect_group(structural_no_opt.union([qubit, noise, gate, debug]))
9+
@ir.dialect_group(structural_no_opt.union([qubit, noise, gate, debug, annotate]))
1010
def kernel(self):
1111
fold_pass = passes.Fold(self)
1212
typeinfer_pass = passes.TypeInfer(self)

src/bloqade/stim/passes/squin_to_stim.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from bloqade.analysis.measure_id import MeasurementIDAnalysis
2828
from bloqade.stim.passes.flatten import Flatten
2929

30-
from ..rewrite.ifs_to_stim import IfToStim
30+
from ..rewrite import IfToStim, SetDetectorToStim, SetObservableToStim
3131

3232

3333
@dataclass
@@ -64,6 +64,8 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
6464
rewrite_result = (
6565
Chain(
6666
Walk(IfToStim(measure_frame=meas_analysis_frame)),
67+
Walk(SetDetectorToStim(measure_id_frame=meas_analysis_frame)),
68+
Walk(SetObservableToStim(measure_id_frame=meas_analysis_frame)),
6769
Fixpoint(Walk(DeadCodeElimination())),
6870
)
6971
.rewrite(mt.code)

src/bloqade/stim/rewrite/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@
33
from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
44
from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
55
from .py_constant_to_stim import PyConstantToStim as PyConstantToStim
6+
from .set_detector_to_stim import SetDetectorToStim as SetDetectorToStim
7+
from .set_observable_to_stim import SetObservableToStim as SetObservableToStim
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from kirin import ir
2+
from kirin.dialects import py
3+
4+
from bloqade.stim.dialects import auxiliary
5+
from bloqade.analysis.measure_id.lattice import MeasureIdBool, MeasureIdTuple
6+
7+
8+
def insert_get_records(
9+
node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count_at_stmt: int
10+
):
11+
"""
12+
Insert GetRecord statements before the given node
13+
"""
14+
get_record_ssas = []
15+
for measure_id_bool in measure_id_tuple.data:
16+
assert isinstance(measure_id_bool, MeasureIdBool)
17+
target_rec_idx = (measure_id_bool.idx - 1) - meas_count_at_stmt
18+
idx_stmt = py.constant.Constant(target_rec_idx)
19+
idx_stmt.insert_before(node)
20+
get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
21+
get_record_stmt.insert_before(node)
22+
get_record_ssas.append(get_record_stmt.result)
23+
24+
return get_record_ssas

0 commit comments

Comments
 (0)