-
Notifications
You must be signed in to change notification settings - Fork 1
squin to stim rewrite #148
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 33 commits
Commits
Show all changes
35 commits
Select commit
Hold shift + click to select a range
4e21ec1
initial steps for squin to stim rewrite
johnzl-777 89a17fd
Wrap rewrite pass
johnzl-777 f6105bb
confirm analysis wrapping works
johnzl-777 ad3a7dc
Merge branch 'main' into 19-rewrite-from-squin-to-stim
johnzl-777 7c70b5e
going to bed
johnzl-777 e55a8fb
preliminary handling of Apply
johnzl-777 e76b3f3
support for control gates confirmed
johnzl-777 7ba5670
finally put everything into a pass
johnzl-777 f3203da
partially working reset rewrite
johnzl-777 20d4214
account for MeasureAndReset
johnzl-777 300f9d7
account for MeasureAndReset, fix up address analysis
johnzl-777 c137221
Merge branch 'main' into 19-rewrite-from-squin-to-stim
johnzl-777 13ae8a5
more testing, verification implemented
johnzl-777 59f763d
remove test call
johnzl-777 9427571
simple site verification test
johnzl-777 e45c10c
saving remaining work before move to codegen
johnzl-777 1a67a03
Merge branch 'main' into 19-rewrite-from-squin-to-stim
johnzl-777 9b6562f
Merge branch 'main' into 19-rewrite-from-squin-to-stim
Roger-luo 30b8a97
account for MeasureQubit, MeasureQubitIlist as well as Broadcast func…
johnzl-777 615a30b
revise tests
johnzl-777 47691e2
remove unnecessary comment
johnzl-777 c8c38ae
Move wrap analysis rewrite into its own file
johnzl-777 72e5f3f
split out reusable utility functions into a seperate file
johnzl-777 4731fe0
fix export problem
johnzl-777 410bcbc
split out rewrite rules, factor in feedback on rewriting wrap/unwrap …
johnzl-777 cab4ab6
get control statement logic to work in wire dialect, simplify stateme…
johnzl-777 c810544
use dict instead of match, just care about type comparison
johnzl-777 fb9ff37
use replace_by, fix analysis impl
johnzl-777 e6c92b7
first round of meeting feedback implemented
johnzl-777 9316a0e
split out Measure rewrite into its own rule
johnzl-777 e3b478a
add tests but codegen is acting weird
johnzl-777 0f759b8
remove site-target dimension check
johnzl-777 877755e
Merge branch 'main' into 19-rewrite-from-squin-to-stim
johnzl-777 4faf282
Merge branch 'main' into 19-rewrite-from-squin-to-stim
johnzl-777 dbc105c
implement second round review feedback
johnzl-777 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from .stim import SquinToStim as SquinToStim |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| from dataclasses import dataclass | ||
|
|
||
| from kirin.passes import Fold | ||
| from kirin.rewrite import ( | ||
| Walk, | ||
| Chain, | ||
| Fixpoint, | ||
| DeadCodeElimination, | ||
| CommonSubexpressionElimination, | ||
| ) | ||
| from kirin.ir.method import Method | ||
| from kirin.passes.abc import Pass | ||
| from kirin.rewrite.abc import RewriteResult | ||
|
|
||
| from bloqade.squin.rewrite import ( | ||
| SquinWireToStim, | ||
| SquinQubitToStim, | ||
| WrapSquinAnalysis, | ||
| SquinMeasureToStim, | ||
| SquinWireIdentityElimination, | ||
| ) | ||
| from bloqade.analysis.address import AddressAnalysis | ||
| from bloqade.squin.analysis.nsites import ( | ||
| NSitesAnalysis, | ||
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| class SquinToStim(Pass): | ||
|
|
||
| def unsafe_run(self, mt: Method) -> RewriteResult: | ||
| fold_pass = Fold(mt.dialects) | ||
| # propagate constants | ||
| rewrite_result = fold_pass(mt) | ||
|
|
||
| # Get necessary analysis results to plug into hints | ||
| address_analysis = AddressAnalysis(mt.dialects) | ||
| address_frame, _ = address_analysis.run_analysis(mt) | ||
| site_analysis = NSitesAnalysis(mt.dialects) | ||
| sites_frame, _ = site_analysis.run_analysis(mt) | ||
|
|
||
| # Wrap Rewrite + SquinToStim can happen w/ standard walk | ||
| rewrite_result = ( | ||
| Walk( | ||
| Chain( | ||
| WrapSquinAnalysis( | ||
| address_analysis=address_frame.entries, | ||
| op_site_analysis=sites_frame.entries, | ||
| ), | ||
| SquinQubitToStim(), | ||
| SquinWireToStim(), | ||
| SquinMeasureToStim(), # reduce duplicated logic, can split out even more rules later | ||
| SquinWireIdentityElimination(), | ||
| ) | ||
| ) | ||
| .rewrite(mt.code) | ||
| .join(rewrite_result) | ||
| ) | ||
|
|
||
| rewrite_result = ( | ||
| Fixpoint( | ||
| Walk(Chain(DeadCodeElimination(), CommonSubexpressionElimination())) | ||
| ) | ||
| .rewrite(mt.code) | ||
| .join(rewrite_result) | ||
| ) | ||
|
|
||
| return rewrite_result |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| from .wire_to_stim import SquinWireToStim as SquinWireToStim | ||
| from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim | ||
| from .squin_measure import SquinMeasureToStim as SquinMeasureToStim | ||
| from .wrap_analysis import ( | ||
| SitesAttribute as SitesAttribute, | ||
| AddressAttribute as AddressAttribute, | ||
| WrapSquinAnalysis as WrapSquinAnalysis, | ||
| ) | ||
| from .wire_identity_elimination import ( | ||
| SquinWireIdentityElimination as SquinWireIdentityElimination, | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| from kirin import ir | ||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||
|
|
||
| from bloqade import stim | ||
| from bloqade.squin import op, qubit | ||
| from bloqade.squin.rewrite.wrap_analysis import AddressAttribute | ||
| from bloqade.squin.rewrite.stim_rewrite_util import ( | ||
| SQUIN_STIM_GATE_MAPPING, | ||
| rewrite_Control, | ||
| insert_qubit_idx_from_address, | ||
| ) | ||
|
|
||
|
|
||
| class SquinQubitToStim(RewriteRule): | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||
|
|
||
| match node: | ||
| case qubit.Apply() | qubit.Broadcast(): | ||
| return self.rewrite_Apply_and_Broadcast(node) | ||
| case qubit.Reset(): | ||
| return self.rewrite_Reset(node) | ||
| case _: | ||
| return RewriteResult() | ||
|
|
||
| def rewrite_Apply_and_Broadcast( | ||
| self, stmt: qubit.Apply | qubit.Broadcast | ||
| ) -> RewriteResult: | ||
| """ | ||
| Rewrite Apply and Broadcast nodes to their stim equivalent statements. | ||
| """ | ||
|
|
||
| # this is an SSAValue, need it to be the actual operator | ||
| applied_op = stmt.operator.owner | ||
| assert isinstance(applied_op, op.stmts.Operator) | ||
|
|
||
| if isinstance(applied_op, op.stmts.Control): | ||
| return rewrite_Control(stmt) | ||
|
|
||
| # need to handle Control through separate means | ||
| # but we can handle X, Y, Z, H, and S here just fine | ||
| stim_1q_op = SQUIN_STIM_GATE_MAPPING.get(type(applied_op)) | ||
| if stim_1q_op is None: | ||
| return RewriteResult() | ||
|
|
||
| address_attr = stmt.qubits.hints.get("address") | ||
| if address_attr is None: | ||
| return RewriteResult() | ||
|
|
||
| assert isinstance(address_attr, AddressAttribute) | ||
| qubit_idx_ssas = insert_qubit_idx_from_address( | ||
| address=address_attr, stmt_to_insert_before=stmt | ||
| ) | ||
|
|
||
| if qubit_idx_ssas is None: | ||
| return RewriteResult() | ||
|
|
||
| stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) | ||
| stmt.replace_by(stim_1q_stmt) | ||
|
|
||
| return RewriteResult(has_done_something=True) | ||
|
|
||
| def rewrite_Reset(self, reset_stmt: qubit.Reset) -> RewriteResult: | ||
| qubit_ilist_ssa = reset_stmt.qubits | ||
| # qubits are in an ilist which makes up an AddressTuple | ||
| address_attr = qubit_ilist_ssa.hints.get("address") | ||
| if address_attr is None: | ||
| return RewriteResult() | ||
|
|
||
| assert isinstance(address_attr, AddressAttribute) | ||
| qubit_idx_ssas = insert_qubit_idx_from_address( | ||
| address=address_attr, stmt_to_insert_before=reset_stmt | ||
| ) | ||
|
|
||
| if qubit_idx_ssas is None: | ||
| return RewriteResult() | ||
|
|
||
| stim_rz_stmt = stim.collapse.stmts.RZ(targets=qubit_idx_ssas) | ||
| reset_stmt.replace_by(stim_rz_stmt) | ||
|
|
||
| return RewriteResult(has_done_something=True) | ||
|
|
||
|
|
||
| # put rewrites for measure statements in separate rule, then just have to dispatch | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| # create rewrite rule name SquinMeasureToStim using kirin | ||
| from kirin import ir | ||
| from kirin.dialects import py | ||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||
|
|
||
| from bloqade import stim | ||
| from bloqade.squin import wire, qubit | ||
| from bloqade.squin.rewrite.wrap_analysis import AddressAttribute | ||
| from bloqade.squin.rewrite.stim_rewrite_util import ( | ||
| is_measure_result_used, | ||
| insert_qubit_idx_from_address, | ||
| ) | ||
|
|
||
|
|
||
| class SquinMeasureToStim(RewriteRule): | ||
| """ | ||
| Rewrite squin measure-related statements to stim statements. | ||
| """ | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||
|
|
||
| match node: | ||
| case qubit.MeasureQubit() | qubit.MeasureQubitList() | wire.Measure(): | ||
| return self.rewrite_Measure(node) | ||
| case qubit.MeasureAndReset() | wire.MeasureAndReset(): | ||
| return self.rewrite_MeasureAndReset(node) | ||
| case _: | ||
| return RewriteResult() | ||
|
|
||
| def rewrite_Measure( | ||
| self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure | ||
| ) -> RewriteResult: | ||
| if is_measure_result_used(measure_stmt): | ||
| return RewriteResult() | ||
|
|
||
| qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt) | ||
| if qubit_idx_ssas is None: | ||
| return RewriteResult() | ||
|
|
||
| prob_noise_stmt = py.constant.Constant(0.0) | ||
| stim_measure_stmt = stim.collapse.MZ( | ||
| p=prob_noise_stmt.result, | ||
| targets=qubit_idx_ssas, | ||
| ) | ||
| prob_noise_stmt.insert_before(measure_stmt) | ||
| measure_stmt.replace_by(stim_measure_stmt) | ||
|
|
||
| return RewriteResult(has_done_something=True) | ||
|
|
||
| def rewrite_MeasureAndReset( | ||
| self, meas_and_reset_stmt: qubit.MeasureAndReset | wire.MeasureAndReset | ||
| ) -> RewriteResult: | ||
| if not is_measure_result_used(meas_and_reset_stmt): | ||
| return RewriteResult() | ||
|
|
||
| qubit_idx_ssas = self.get_qubit_idx_ssas(meas_and_reset_stmt) | ||
|
|
||
| if qubit_idx_ssas is None: | ||
| return RewriteResult() | ||
|
|
||
| error_p_stmt = py.Constant(0.0) | ||
| stim_mz_stmt = stim.collapse.MZ(targets=qubit_idx_ssas, p=error_p_stmt.result) | ||
| stim_rz_stmt = stim.collapse.RZ( | ||
| targets=qubit_idx_ssas, | ||
| ) | ||
|
|
||
| error_p_stmt.insert_before(meas_and_reset_stmt) | ||
| stim_mz_stmt.insert_before(meas_and_reset_stmt) | ||
| meas_and_reset_stmt.replace_by(stim_rz_stmt) | ||
|
|
||
| return RewriteResult(has_done_something=True) | ||
|
|
||
| def get_qubit_idx_ssas( | ||
| self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure | ||
| ) -> tuple[ir.SSAValue, ...] | None: | ||
| """ | ||
| Extract the address attribute and insert qubit indices for the given measure statement. | ||
| """ | ||
| match measure_stmt: | ||
| case qubit.MeasureQubit(): | ||
| address_attr = measure_stmt.qubit.hints.get("address") | ||
| case qubit.MeasureQubitList(): | ||
| address_attr = measure_stmt.qubits.hints.get("address") | ||
| case wire.Measure(): | ||
| address_attr = measure_stmt.wire.hints.get("address") | ||
| case _: | ||
| return None | ||
|
|
||
| if address_attr is None: | ||
| return None | ||
|
|
||
| assert isinstance(address_attr, AddressAttribute) | ||
|
|
||
| qubit_idx_ssas = insert_qubit_idx_from_address( | ||
| address=address_attr, stmt_to_insert_before=measure_stmt | ||
| ) | ||
|
|
||
| return qubit_idx_ssas | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.