Skip to content

Commit 94505d5

Browse files
authored
detector/observable annotation dialect (#603)
1 parent 2ffb517 commit 94505d5

28 files changed

+831
-5
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ repos:
77
- id: check-yaml
88
args: ['--unsafe']
99
- id: end-of-file-fixer
10+
exclude: .*\.stim$
1011
- id: trailing-whitespace
1112
- repo: https://github.com/pycqa/isort
1213
rev: 6.0.1

src/bloqade/analysis/measure_id/impls.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.analysis import const
33
from kirin.dialects import py, scf, func, ilist
44

5-
from bloqade import qubit
5+
from bloqade import qubit, annotate
66

77
from .lattice import (
88
AnyMeasureId,
@@ -46,6 +46,20 @@ def measure_qubit_list(
4646
return (MeasureIdTuple(data=tuple(measure_id_bools)),)
4747

4848

49+
@annotate.dialect.register(key="measure_id")
50+
class Annotate(interp.MethodTable):
51+
@interp.impl(annotate.stmts.SetObservable)
52+
@interp.impl(annotate.stmts.SetDetector)
53+
def consumes_measurement_results(
54+
self,
55+
interp: MeasurementIDAnalysis,
56+
frame: MeasureIDFrame,
57+
stmt: annotate.stmts.SetObservable | annotate.stmts.SetDetector,
58+
):
59+
frame.num_measures_at_stmt[stmt] = interp.measure_count
60+
return (NotMeasureId(),)
61+
62+
4963
@ilist.dialect.register(key="measure_id")
5064
class IList(interp.MethodTable):
5165
@interp.impl(ilist.New)

src/bloqade/annotate/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from . import stmts as stmts
2+
from ._dialect import dialect as dialect
3+
from ._interface import (
4+
set_detector as set_detector,
5+
set_observable as set_observable,
6+
)

src/bloqade/annotate/_dialect.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from kirin import ir
2+
3+
dialect = ir.Dialect("squin.annotate")

src/bloqade/annotate/_interface.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from typing import Any
2+
3+
from kirin.dialects import ilist
4+
from kirin.lowering import wraps
5+
6+
from bloqade.types import MeasurementResult
7+
8+
from .stmts import SetDetector, SetObservable
9+
from .types import Detector, Observable
10+
11+
12+
@wraps(SetDetector)
13+
def set_detector(
14+
measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult],
15+
coordinates: ilist.IList[float | int, Any] | list[float | int],
16+
) -> Detector: ...
17+
18+
19+
@wraps(SetObservable)
20+
def set_observable(
21+
measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult],
22+
) -> Observable: ...

src/bloqade/annotate/stmts.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from kirin import ir, types as kirin_types, lowering
2+
from kirin.decl import info, statement
3+
from kirin.dialects import ilist
4+
5+
from bloqade.types import MeasurementResultType
6+
from bloqade.annotate.types import DetectorType, ObservableType
7+
8+
from ._dialect import dialect
9+
10+
11+
@statement
12+
class ConsumesMeasurementResults(ir.Statement):
13+
traits = frozenset({lowering.FromPythonCall()})
14+
measurements: ir.SSAValue = info.argument(
15+
ilist.IListType[MeasurementResultType, kirin_types.Any]
16+
)
17+
18+
19+
@statement(dialect=dialect)
20+
class SetDetector(ConsumesMeasurementResults):
21+
coordinates: ir.SSAValue = info.argument(
22+
type=ilist.IListType[kirin_types.Int | kirin_types.Float, kirin_types.Any]
23+
)
24+
result: ir.ResultValue = info.result(DetectorType)
25+
26+
27+
@statement(dialect=dialect)
28+
class SetObservable(ConsumesMeasurementResults):
29+
result: ir.ResultValue = info.result(ObservableType)

src/bloqade/annotate/types.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from kirin import types
2+
3+
4+
class Detector:
5+
pass
6+
7+
8+
class Observable:
9+
pass
10+
11+
12+
DetectorType = types.PyClass(Detector)
13+
ObservableType = types.PyClass(Observable)

src/bloqade/squin/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
noise as noise,
44
analysis as analysis,
55
)
6-
from .. import qubit as qubit
6+
from .. import qubit as qubit, annotate as annotate
77
from ..qubit import (
88
reset as reset,
99
qalloc as qalloc,
@@ -12,6 +12,7 @@
1212
get_measurement_id as get_measurement_id,
1313
)
1414
from .groups import kernel as kernel
15+
from ..annotate import set_detector as set_detector, set_observable as set_observable
1516
from .stdlib.simple import (
1617
h as h,
1718
s as s,

src/bloqade/squin/groups.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from kirin.dialects import debug, ilist
44

55
from . import gate, noise
6-
from .. import qubit
6+
from .. import qubit, annotate
77

88

9-
@ir.dialect_group(structural_no_opt.union([qubit, noise, gate, debug]))
9+
@ir.dialect_group(structural_no_opt.union([qubit, noise, gate, debug, annotate]))
1010
def kernel(self):
1111
fold_pass = passes.Fold(self)
1212
typeinfer_pass = passes.TypeInfer(self)

src/bloqade/stim/passes/squin_to_stim.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from bloqade.analysis.measure_id import MeasurementIDAnalysis
2828
from bloqade.stim.passes.flatten import Flatten
2929

30-
from ..rewrite.ifs_to_stim import IfToStim
30+
from ..rewrite import IfToStim, SetDetectorToStim, SetObservableToStim
3131

3232

3333
@dataclass
@@ -64,6 +64,8 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
6464
rewrite_result = (
6565
Chain(
6666
Walk(IfToStim(measure_frame=meas_analysis_frame)),
67+
Walk(SetDetectorToStim(measure_id_frame=meas_analysis_frame)),
68+
Walk(SetObservableToStim(measure_id_frame=meas_analysis_frame)),
6769
Fixpoint(Walk(DeadCodeElimination())),
6870
)
6971
.rewrite(mt.code)

0 commit comments

Comments
 (0)