Skip to content

Commit 62769e7

Browse files
committed
Merge branch 'main' into john/measurement-result-predicate
2 parents 600a48a + 94505d5 commit 62769e7

30 files changed

+834
-5
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ repos:
77
- id: check-yaml
88
args: ['--unsafe']
99
- id: end-of-file-fixer
10+
exclude: .*\.stim$
1011
- id: trailing-whitespace
1112
- repo: https://github.com/pycqa/isort
1213
rev: 6.0.1

src/bloqade/analysis/measure_id/impls.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.analysis import const
33
from kirin.dialects import py, scf, func, ilist
44

5-
from bloqade import qubit
5+
from bloqade import qubit, annotate
66

77
from .lattice import (
88
Predicate,
@@ -83,6 +83,20 @@ def measurement_predicate(
8383
return (MeasureIdTuple(data=tuple(predicate_measure_ids)),)
8484

8585

86+
@annotate.dialect.register(key="measure_id")
87+
class Annotate(interp.MethodTable):
88+
@interp.impl(annotate.stmts.SetObservable)
89+
@interp.impl(annotate.stmts.SetDetector)
90+
def consumes_measurement_results(
91+
self,
92+
interp: MeasurementIDAnalysis,
93+
frame: MeasureIDFrame,
94+
stmt: annotate.stmts.SetObservable | annotate.stmts.SetDetector,
95+
):
96+
frame.num_measures_at_stmt[stmt] = interp.measure_count
97+
return (NotMeasureId(),)
98+
99+
86100
@ilist.dialect.register(key="measure_id")
87101
class IList(interp.MethodTable):
88102
@interp.impl(ilist.New)

src/bloqade/annotate/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from . import stmts as stmts
2+
from ._dialect import dialect as dialect
3+
from ._interface import (
4+
set_detector as set_detector,
5+
set_observable as set_observable,
6+
)

src/bloqade/annotate/_dialect.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from kirin import ir
2+
3+
dialect = ir.Dialect("squin.annotate")

src/bloqade/annotate/_interface.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from typing import Any
2+
3+
from kirin.dialects import ilist
4+
from kirin.lowering import wraps
5+
6+
from bloqade.types import MeasurementResult
7+
8+
from .stmts import SetDetector, SetObservable
9+
from .types import Detector, Observable
10+
11+
12+
@wraps(SetDetector)
13+
def set_detector(
14+
measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult],
15+
coordinates: ilist.IList[float | int, Any] | list[float | int],
16+
) -> Detector: ...
17+
18+
19+
@wraps(SetObservable)
20+
def set_observable(
21+
measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult],
22+
) -> Observable: ...

src/bloqade/annotate/stmts.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from kirin import ir, types as kirin_types, lowering
2+
from kirin.decl import info, statement
3+
from kirin.dialects import ilist
4+
5+
from bloqade.types import MeasurementResultType
6+
from bloqade.annotate.types import DetectorType, ObservableType
7+
8+
from ._dialect import dialect
9+
10+
11+
@statement
12+
class ConsumesMeasurementResults(ir.Statement):
13+
traits = frozenset({lowering.FromPythonCall()})
14+
measurements: ir.SSAValue = info.argument(
15+
ilist.IListType[MeasurementResultType, kirin_types.Any]
16+
)
17+
18+
19+
@statement(dialect=dialect)
20+
class SetDetector(ConsumesMeasurementResults):
21+
coordinates: ir.SSAValue = info.argument(
22+
type=ilist.IListType[kirin_types.Int | kirin_types.Float, kirin_types.Any]
23+
)
24+
result: ir.ResultValue = info.result(DetectorType)
25+
26+
27+
@statement(dialect=dialect)
28+
class SetObservable(ConsumesMeasurementResults):
29+
result: ir.ResultValue = info.result(ObservableType)

src/bloqade/annotate/types.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from kirin import types
2+
3+
4+
class Detector:
5+
pass
6+
7+
8+
class Observable:
9+
pass
10+
11+
12+
DetectorType = types.PyClass(Detector)
13+
ObservableType = types.PyClass(Observable)

src/bloqade/rewrite/passes/aggressive_unroll.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
4242
ilist.rewrite.InlineGetItem(),
4343
ilist.rewrite.FlattenAdd(),
4444
ilist.rewrite.HintLen(),
45+
DeadCodeElimination(),
4546
)
4647
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
4748

src/bloqade/rewrite/passes/canonicalize_ilist.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Walk,
66
Chain,
77
Fixpoint,
8+
DeadCodeElimination,
89
)
910
from kirin.dialects.ilist import rewrite
1011

@@ -24,6 +25,7 @@ def unsafe_run(self, mt: ir.Method):
2425
rewrite.InlineGetItem(),
2526
rewrite.FlattenAdd(),
2627
rewrite.HintLen(),
28+
DeadCodeElimination(),
2729
)
2830
)
2931
).rewrite(mt.code)

src/bloqade/squin/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
noise as noise,
44
analysis as analysis,
55
)
6-
from .. import qubit as qubit
6+
from .. import qubit as qubit, annotate as annotate
77
from ..qubit import (
88
reset as reset,
99
is_one as is_one,
@@ -15,6 +15,7 @@
1515
get_measurement_id as get_measurement_id,
1616
)
1717
from .groups import kernel as kernel
18+
from ..annotate import set_detector as set_detector, set_observable as set_observable
1819
from .stdlib.simple import (
1920
h as h,
2021
s as s,

0 commit comments

Comments
 (0)