-
Notifications
You must be signed in to change notification settings - Fork 1
detector/observable annotation dialect #603
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
e85223f
0a0abcf
f19345c
98be61d
0ba5c33
3b095d9
b46eda9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| from . import stmts as stmts | ||
| from ._dialect import dialect as dialect | ||
| from ._interface import ( | ||
| set_detector as set_detector, | ||
| set_observable as set_observable, | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| from kirin import ir | ||
|
|
||
| dialect = ir.Dialect("squin.annotate") |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| from typing import Any | ||
|
|
||
| from kirin.dialects import ilist | ||
| from kirin.lowering import wraps | ||
|
|
||
| from bloqade.types import MeasurementResult | ||
|
|
||
| from .stmts import SetDetector, SetObservable | ||
|
|
||
|
|
||
| @wraps(SetDetector) | ||
| def set_detector( | ||
| measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult], | ||
| coordinates: tuple[float | int, ...], | ||
| ) -> None: ... | ||
|
|
||
|
|
||
| @wraps(SetObservable) | ||
| def set_observable( | ||
| measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult], | ||
| ) -> None: ... | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| from kirin import ir, types as kirin_types, lowering | ||
| from kirin.decl import info, statement | ||
| from kirin.dialects import ilist | ||
|
|
||
| from bloqade.types import MeasurementResultType | ||
|
|
||
| from ._dialect import dialect | ||
|
|
||
|
|
||
| @statement | ||
| class ConsumesMeasurementResults(ir.Statement): | ||
| traits = frozenset({lowering.FromPythonCall()}) | ||
| inputs: ir.SSAValue = info.argument( | ||
| ilist.IListType[MeasurementResultType, kirin_types.Any] | ||
| ) | ||
|
|
||
|
|
||
| @statement(dialect=dialect) | ||
| class SetDetector(ConsumesMeasurementResults): | ||
| coordinates: ir.SSAValue = info.argument( | ||
| type=kirin_types.Tuple[kirin_types.Int | kirin_types.Float] | ||
|
||
| ) | ||
|
|
||
|
|
||
| @statement(dialect=dialect) | ||
| class SetObservable(ConsumesMeasurementResults): | ||
| pass | ||
johnzl-777 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| from kirin import ir | ||
| from kirin.dialects import py | ||
|
|
||
| from bloqade.stim.dialects import auxiliary | ||
| from bloqade.analysis.measure_id.lattice import MeasureIdBool, MeasureIdTuple | ||
|
|
||
|
|
||
| def insert_get_records( | ||
| node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count_at_stmt: int | ||
| ): | ||
| """ | ||
| Insert GetRecord statements before the given node | ||
| """ | ||
| get_record_ssas = [] | ||
| for measure_id_bool in measure_id_tuple.data: | ||
| assert isinstance(measure_id_bool, MeasureIdBool) | ||
| target_rec_idx = (measure_id_bool.idx - 1) - meas_count_at_stmt | ||
| idx_stmt = py.constant.Constant(target_rec_idx) | ||
| idx_stmt.insert_before(node) | ||
| get_record_stmt = auxiliary.GetRecord(idx_stmt.result) | ||
| get_record_stmt.insert_before(node) | ||
| get_record_ssas.append(get_record_stmt.result) | ||
|
|
||
| return get_record_ssas |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| from typing import Iterable | ||
| from dataclasses import dataclass | ||
|
|
||
| from kirin import ir | ||
| from kirin.dialects.py import Constant | ||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||
|
|
||
| from bloqade.stim.dialects import auxiliary | ||
| from bloqade.annotate.stmts import SetDetector | ||
| from bloqade.analysis.measure_id import MeasureIDFrame | ||
| from bloqade.stim.dialects.auxiliary import Detector | ||
| from bloqade.analysis.measure_id.lattice import MeasureIdTuple | ||
|
|
||
| from ..rewrite.get_record_util import insert_get_records | ||
|
|
||
|
|
||
| @dataclass | ||
| class SetDetectorToStim(RewriteRule): | ||
| """ | ||
| Rewrite SetDetector to GetRecord and Detector in the stim dialect | ||
| """ | ||
|
|
||
| measure_id_frame: MeasureIDFrame | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||
| match node: | ||
| case SetDetector(): | ||
| return self.rewrite_SetDetector(node) | ||
| case _: | ||
| return RewriteResult() | ||
|
|
||
| def rewrite_SetDetector(self, node: SetDetector) -> RewriteResult: | ||
|
|
||
| # get coordinates and generate correct consts | ||
| coord_ssas = [] | ||
| if not isinstance(node.coordinates.owner, Constant): | ||
| return RewriteResult() | ||
|
|
||
| coord_values = node.coordinates.owner.value.unwrap() | ||
|
|
||
| if not isinstance(coord_values, Iterable): | ||
| return RewriteResult() | ||
|
|
||
| if any(not isinstance(value, (int, float)) for value in coord_values): | ||
| return RewriteResult() | ||
|
|
||
| for coord_value in coord_values: | ||
| if isinstance(coord_value, float): | ||
| coord_stmt = auxiliary.ConstFloat(value=coord_value) | ||
| else: # int | ||
| coord_stmt = auxiliary.ConstInt(value=coord_value) | ||
| coord_ssas.append(coord_stmt.result) | ||
| coord_stmt.insert_before(node) | ||
|
|
||
| measure_ids = self.measure_id_frame.entries[node.inputs] | ||
| assert isinstance(measure_ids, MeasureIdTuple) | ||
|
|
||
| get_record_list = insert_get_records( | ||
| node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node] | ||
| ) | ||
|
|
||
| detector_stmt = Detector( | ||
| coord=tuple(coord_ssas), targets=tuple(get_record_list) | ||
| ) | ||
|
|
||
| node.replace_by(detector_stmt) | ||
|
|
||
| return RewriteResult(has_done_something=True) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| from dataclasses import dataclass | ||
|
|
||
| from kirin import ir | ||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||
|
|
||
| from bloqade.stim.dialects import auxiliary | ||
| from bloqade.annotate.stmts import SetObservable | ||
| from bloqade.analysis.measure_id import MeasureIDFrame | ||
| from bloqade.stim.dialects.auxiliary import ObservableInclude | ||
| from bloqade.analysis.measure_id.lattice import MeasureIdTuple | ||
|
|
||
| from ..rewrite.get_record_util import insert_get_records | ||
|
|
||
|
|
||
| @dataclass | ||
| class SetObservableToStim(RewriteRule): | ||
| """ | ||
| Rewrite SetObservable to GetRecord and ObservableInclude in the stim dialect | ||
| """ | ||
|
|
||
| measure_id_frame: MeasureIDFrame | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||
| match node: | ||
| case SetObservable(): | ||
| return self.rewrite_SetObservable(node) | ||
| case _: | ||
| return RewriteResult() | ||
|
|
||
| def rewrite_SetObservable(self, node: SetObservable) -> RewriteResult: | ||
|
|
||
| # set idx to 0 for now, but this | ||
| # should be something that a user can set on their own. | ||
| # SetObservable needs to accept an int. | ||
|
|
||
| idx_stmt = auxiliary.ConstInt(value=0) | ||
| idx_stmt.insert_before(node) | ||
|
|
||
| measure_ids = self.measure_id_frame.entries[node.inputs] | ||
| assert isinstance(measure_ids, MeasureIdTuple) | ||
|
|
||
| get_record_list = insert_get_records( | ||
| node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node] | ||
| ) | ||
|
|
||
| observable_include_stmt = ObservableInclude( | ||
| idx=idx_stmt.result, targets=tuple(get_record_list) | ||
| ) | ||
|
|
||
| node.replace_by(observable_include_stmt) | ||
|
|
||
| return RewriteResult(has_done_something=True) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,32 @@ | ||
| from kirin import ir | ||
|
|
||
| from bloqade import squin | ||
| from bloqade.stim.emit import EmitStimMain | ||
| from bloqade.stim.passes import SquinToStimPass | ||
|
|
||
|
|
||
| def codegen(mt: ir.Method): | ||
| # method should not have any arguments! | ||
| emit = EmitStimMain() | ||
| emit.initialize() | ||
| emit.run(mt=mt, args=()) | ||
| return emit.get_output() | ||
|
|
||
|
|
||
| def test_annotate(): | ||
|
|
||
| @squin.kernel | ||
| def test(): | ||
| qs = squin.qalloc(4) | ||
| ms = squin.broadcast.measure(qs) | ||
| squin.set_detector([ms[0], ms[1], ms[2]], coordinates=(0, 0)) | ||
| squin.set_observable([ms[3]]) | ||
|
|
||
| SquinToStimPass(dialects=test.dialects)(test) | ||
| codegen_output = codegen(test) | ||
| expected_output = ( | ||
| "\nMZ(0.00000000) 0 1 2 3\n" | ||
| "DETECTOR(0, 0) rec[-4] rec[-3] rec[-2]\n" | ||
| "OBSERVABLE_INCLUDE(0) rec[-1]" | ||
| ) | ||
| assert codegen_output == expected_output |
Uh oh!
There was an error while loading. Please reload this page.