-
Notifications
You must be signed in to change notification settings - Fork 1
detector/observable annotation dialect #603
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 6 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
e85223f
initial annotate dialect implementation
johnzl-777 0a0abcf
get CI to be happy with initial test
johnzl-777 f19345c
bring in return types
johnzl-777 98be61d
add in unit tests
johnzl-777 0ba5c33
get linter to stop messing up .stim files
johnzl-777 3b095d9
Merge branch 'main' into john/port-annotate-dialect
johnzl-777 b46eda9
change up coordinate type and update tests
johnzl-777 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| from . import stmts as stmts | ||
| from ._dialect import dialect as dialect | ||
| from ._interface import ( | ||
| set_detector as set_detector, | ||
| set_observable as set_observable, | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| from kirin import ir | ||
|
|
||
| dialect = ir.Dialect("squin.annotate") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| from typing import Any | ||
|
|
||
| from kirin.dialects import ilist | ||
| from kirin.lowering import wraps | ||
|
|
||
| from bloqade.types import MeasurementResult | ||
|
|
||
| from .stmts import SetDetector, SetObservable | ||
| from .types import Detector, Observable | ||
|
|
||
|
|
||
| @wraps(SetDetector) | ||
| def set_detector( | ||
| measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult], | ||
| coordinates: tuple[float | int, ...], | ||
| ) -> Detector: ... | ||
|
|
||
|
|
||
| @wraps(SetObservable) | ||
| def set_observable( | ||
| measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult], | ||
| ) -> Observable: ... |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| from kirin import ir, types as kirin_types, lowering | ||
| from kirin.decl import info, statement | ||
| from kirin.dialects import ilist | ||
|
|
||
| from bloqade.types import MeasurementResultType | ||
| from bloqade.annotate.types import DetectorType, ObservableType | ||
|
|
||
| from ._dialect import dialect | ||
|
|
||
|
|
||
| @statement | ||
| class ConsumesMeasurementResults(ir.Statement): | ||
| traits = frozenset({lowering.FromPythonCall()}) | ||
| measurements: ir.SSAValue = info.argument( | ||
| ilist.IListType[MeasurementResultType, kirin_types.Any] | ||
| ) | ||
|
|
||
|
|
||
| @statement(dialect=dialect) | ||
| class SetDetector(ConsumesMeasurementResults): | ||
| coordinates: ir.SSAValue = info.argument( | ||
| type=kirin_types.Tuple[kirin_types.Int | kirin_types.Float] | ||
| ) | ||
| result: ir.ResultValue = info.result(DetectorType) | ||
|
|
||
|
|
||
| @statement(dialect=dialect) | ||
| class SetObservable(ConsumesMeasurementResults): | ||
| result: ir.ResultValue = info.result(ObservableType) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| from kirin import types | ||
|
|
||
|
|
||
| class Detector: | ||
| pass | ||
|
|
||
|
|
||
| class Observable: | ||
| pass | ||
|
|
||
|
|
||
| DetectorType = types.PyClass(Detector) | ||
| ObservableType = types.PyClass(Observable) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| from kirin import ir | ||
| from kirin.dialects import py | ||
|
|
||
| from bloqade.stim.dialects import auxiliary | ||
| from bloqade.analysis.measure_id.lattice import MeasureIdBool, MeasureIdTuple | ||
|
|
||
|
|
||
| def insert_get_records( | ||
| node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count_at_stmt: int | ||
| ): | ||
| """ | ||
| Insert GetRecord statements before the given node | ||
| """ | ||
| get_record_ssas = [] | ||
| for measure_id_bool in measure_id_tuple.data: | ||
| assert isinstance(measure_id_bool, MeasureIdBool) | ||
| target_rec_idx = (measure_id_bool.idx - 1) - meas_count_at_stmt | ||
| idx_stmt = py.constant.Constant(target_rec_idx) | ||
| idx_stmt.insert_before(node) | ||
| get_record_stmt = auxiliary.GetRecord(idx_stmt.result) | ||
| get_record_stmt.insert_before(node) | ||
| get_record_ssas.append(get_record_stmt.result) | ||
|
|
||
| return get_record_ssas |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| from typing import Iterable | ||
| from dataclasses import dataclass | ||
|
|
||
| from kirin import ir | ||
| from kirin.dialects.py import Constant | ||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||
|
|
||
| from bloqade.stim.dialects import auxiliary | ||
| from bloqade.annotate.stmts import SetDetector | ||
| from bloqade.analysis.measure_id import MeasureIDFrame | ||
| from bloqade.stim.dialects.auxiliary import Detector | ||
| from bloqade.analysis.measure_id.lattice import MeasureIdTuple | ||
|
|
||
| from ..rewrite.get_record_util import insert_get_records | ||
|
|
||
|
|
||
| @dataclass | ||
| class SetDetectorToStim(RewriteRule): | ||
| """ | ||
| Rewrite SetDetector to GetRecord and Detector in the stim dialect | ||
| """ | ||
|
|
||
| measure_id_frame: MeasureIDFrame | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||
| match node: | ||
| case SetDetector(): | ||
| return self.rewrite_SetDetector(node) | ||
| case _: | ||
| return RewriteResult() | ||
|
|
||
| def rewrite_SetDetector(self, node: SetDetector) -> RewriteResult: | ||
|
|
||
| # get coordinates and generate correct consts | ||
| coord_ssas = [] | ||
| if not isinstance(node.coordinates.owner, Constant): | ||
| return RewriteResult() | ||
|
|
||
| coord_values = node.coordinates.owner.value.unwrap() | ||
|
|
||
| if not isinstance(coord_values, Iterable): | ||
| return RewriteResult() | ||
|
|
||
| if any(not isinstance(value, (int, float)) for value in coord_values): | ||
| return RewriteResult() | ||
|
|
||
| for coord_value in coord_values: | ||
| if isinstance(coord_value, float): | ||
| coord_stmt = auxiliary.ConstFloat(value=coord_value) | ||
| else: # int | ||
| coord_stmt = auxiliary.ConstInt(value=coord_value) | ||
| coord_ssas.append(coord_stmt.result) | ||
| coord_stmt.insert_before(node) | ||
|
|
||
| measure_ids = self.measure_id_frame.entries[node.measurements] | ||
| assert isinstance(measure_ids, MeasureIdTuple) | ||
|
|
||
| get_record_list = insert_get_records( | ||
| node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node] | ||
| ) | ||
|
|
||
| detector_stmt = Detector( | ||
| coord=tuple(coord_ssas), targets=tuple(get_record_list) | ||
| ) | ||
|
|
||
| node.replace_by(detector_stmt) | ||
|
|
||
| return RewriteResult(has_done_something=True) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| from dataclasses import dataclass | ||
|
|
||
| from kirin import ir | ||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||
|
|
||
| from bloqade.stim.dialects import auxiliary | ||
| from bloqade.annotate.stmts import SetObservable | ||
| from bloqade.analysis.measure_id import MeasureIDFrame | ||
| from bloqade.stim.dialects.auxiliary import ObservableInclude | ||
| from bloqade.analysis.measure_id.lattice import MeasureIdTuple | ||
|
|
||
| from ..rewrite.get_record_util import insert_get_records | ||
|
|
||
|
|
||
| @dataclass | ||
| class SetObservableToStim(RewriteRule): | ||
| """ | ||
| Rewrite SetObservable to GetRecord and ObservableInclude in the stim dialect | ||
| """ | ||
|
|
||
| measure_id_frame: MeasureIDFrame | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||
| match node: | ||
| case SetObservable(): | ||
| return self.rewrite_SetObservable(node) | ||
| case _: | ||
| return RewriteResult() | ||
|
|
||
| def rewrite_SetObservable(self, node: SetObservable) -> RewriteResult: | ||
|
|
||
| # set idx to 0 for now, but this | ||
| # should be something that a user can set on their own. | ||
| # SetObservable needs to accept an int. | ||
|
|
||
| idx_stmt = auxiliary.ConstInt(value=0) | ||
| idx_stmt.insert_before(node) | ||
|
|
||
| measure_ids = self.measure_id_frame.entries[node.measurements] | ||
| assert isinstance(measure_ids, MeasureIdTuple) | ||
|
|
||
| get_record_list = insert_get_records( | ||
| node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node] | ||
| ) | ||
|
|
||
| observable_include_stmt = ObservableInclude( | ||
| idx=idx_stmt.result, targets=tuple(get_record_list) | ||
| ) | ||
|
|
||
| node.replace_by(observable_include_stmt) | ||
|
|
||
| return RewriteResult(has_done_something=True) |
2 changes: 2 additions & 0 deletions
2
test/stim/passes/stim_reference_programs/annotate/broadcast_with_alias.stim
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
|
|
||
| X 0 1 |
38 changes: 38 additions & 0 deletions
38
test/stim/passes/stim_reference_programs/annotate/clean_surface_code.stim
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
|
|
||
| RZ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | ||
| H 13 14 15 16 | ||
| CX 13 14 16 1 3 7 0 2 4 10 11 12 | ||
| CX 13 14 16 2 4 8 3 5 7 10 11 12 | ||
| CX 13 15 16 0 4 6 1 3 5 9 10 11 | ||
| CX 13 15 16 1 5 7 4 6 8 9 10 11 | ||
| H 13 14 15 16 | ||
| MZ(0.00000000) 13 14 15 16 | ||
| MZ(0.00000000) 9 10 11 12 | ||
| RZ 13 14 15 16 9 10 11 12 | ||
| DETECTOR(0, 0) rec[-4] | ||
| DETECTOR(0, 0) rec[-3] | ||
| DETECTOR(0, 0) rec[-2] | ||
| DETECTOR(0, 0) rec[-1] | ||
| H 13 14 15 16 | ||
| CX 13 14 16 1 3 7 0 2 4 10 11 12 | ||
| CX 13 14 16 2 4 8 3 5 7 10 11 12 | ||
| CX 13 15 16 0 4 6 1 3 5 9 10 11 | ||
| CX 13 15 16 1 5 7 4 6 8 9 10 11 | ||
| H 13 14 15 16 | ||
| MZ(0.00000000) 13 14 15 16 | ||
| MZ(0.00000000) 9 10 11 12 | ||
| RZ 13 14 15 16 9 10 11 12 | ||
| DETECTOR(0, 0) rec[-4] rec[-12] | ||
| DETECTOR(0, 0) rec[-3] rec[-11] | ||
| DETECTOR(0, 0) rec[-2] rec[-10] | ||
| DETECTOR(0, 0) rec[-1] rec[-9] | ||
| DETECTOR(0, 0) rec[-8] rec[-16] | ||
| DETECTOR(0, 0) rec[-7] rec[-15] | ||
| DETECTOR(0, 0) rec[-6] rec[-14] | ||
| DETECTOR(0, 0) rec[-5] rec[-13] | ||
| MZ(0.00000000) 0 1 2 3 4 5 6 7 8 | ||
| DETECTOR(0, 0) rec[-9] rec[-8] rec[-13] | ||
| DETECTOR(0, 0) rec[-8] rec[-7] rec[-5] rec[-4] rec[-12] | ||
| DETECTOR(0, 0) rec[-6] rec[-5] rec[-3] rec[-2] rec[-11] | ||
| DETECTOR(0, 0) rec[-2] rec[-1] rec[-10] | ||
| OBSERVABLE_INCLUDE(0) rec[-9] rec[-8] rec[-7] |
9 changes: 9 additions & 0 deletions
9
test/stim/passes/stim_reference_programs/annotate/detector_coords_as_args.stim
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
|
|
||
| MZ(0.00000000) 0 1 | ||
| DETECTOR(0, 1) rec[-2] rec[-1] | ||
| DETECTOR(3, 4) rec[-2] rec[-1] | ||
| DETECTOR(0, 5.00000000) rec[-2] rec[-1] | ||
| DETECTOR(5.00000000, 3) rec[-2] rec[-1] | ||
| DETECTOR(1, 2, 5.00000000) rec[-2] rec[-1] | ||
| DETECTOR(1, 2) rec[-2] rec[-1] | ||
| DETECTOR(1, 2, 5.00000000) rec[-2] rec[-1] |
6 changes: 6 additions & 0 deletions
6
test/stim/passes/stim_reference_programs/annotate/kernel_base_op.stim
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
|
|
||
| H 1 3 2 | ||
| MZ(0.00000000) 1 3 2 | ||
| DETECTOR(0, 0) rec[-3] | ||
| DETECTOR(0, 0) rec[-2] | ||
| DETECTOR(0, 0) rec[-1] |
11 changes: 11 additions & 0 deletions
11
test/stim/passes/stim_reference_programs/annotate/linear_program_rewrite.stim
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
|
|
||
| X 0 | ||
| Y 1 | ||
| Z 2 | ||
| CX 0 1 | ||
| CX 0 1 2 3 | ||
| Z 0 1 2 3 | ||
| MZ(0.00000000) 0 1 2 3 | ||
| DETECTOR(0.00000000, 0.00000000) rec[-4] rec[-3] | ||
| DETECTOR(1.00000000, 0.00000000) rec[-3] rec[-2] | ||
| OBSERVABLE_INCLUDE(0) rec[-2] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this be an
IList? Is the length here known?Judging from the tests, this could be
IListType[types.Float, types.Literal(2)].Do we need to support
Inthere?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this type is wrong.
(0, 3.2, 5)but I think we should actually not allow it moving forward (error when validate in the future)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tupletype actually aliases to a vararg tuple.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, so the right way to hint it (if you really want to use tuple) is:
types.Tuple[types.Float, ...]which is equivalent totypes.Tuple[types.Vararg(types.Float)]Also, for 2. we can discuss if we want that UX. I don't have strong opinion on that.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kaihsin why do we want to prohibit variable length tuple for the coordinates? My understanding is they just exist to make visualization easier, it doesn't affect any functionality.
@david-pl I think the only reason I cooked up tests with two integer coordinates is because I don't think I've ever seen a 3D coordinate used. That being said, I imagine there must be more complicated codes where this would be desirable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kaihsin's point with the tuple length was that ´Tuple[Int]´ indicates a tuple with only a single element. So the correct typing for variable length here would be ´Tuple[Int, ...]´.
I see, then disregard the comment on the length.