-
Notifications
You must be signed in to change notification settings - Fork 1
Non-Clifford to U3 rewrite. #647
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
2468f16
adding rewrite
weinbe58 ae462fc
fixing bugs + adding tests
weinbe58 8eb84b7
Merge branch 'main' into phil/rewrite-non-clifford-u3-2
weinbe58 23da678
Removing default
weinbe58 5d3159b
Merge branch 'phil/rewrite-non-clifford-u3-2' of https://github.com/Q…
weinbe58 84509cc
Merge branch 'main' into phil/rewrite-non-clifford-u3-2
weinbe58 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| from kirin import ir | ||
| from kirin.rewrite import abc as rewrite_abc | ||
| from kirin.dialects import py | ||
|
|
||
| from bloqade.squin.gate import stmts as gate_stmts | ||
|
|
||
|
|
||
| class RewriteNonCliffordToU3(rewrite_abc.RewriteRule): | ||
| """Rewrite non-Clifford gates to U3 gates. | ||
|
|
||
| This rewrite rule transforms specific non-Clifford single-qubit gates | ||
| into equivalent U3 gate representations. The following transformations are applied: | ||
| - T gate (with adjoint attribute) to U3 gate with parameters (0, 0, ±π/4) | ||
| - Rx gate to U3 gate with parameters (angle, -π/2, π/2) | ||
| - Ry gate to U3 gate with parameters (angle, 0, 0) | ||
| - Rz gate is U3 gate with parameters (0, 0, angle) | ||
|
|
||
| This rewrite should be paired with `U3ToClifford` to canonicalize the circuit. | ||
|
|
||
| """ | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult: | ||
| if not isinstance( | ||
| node, | ||
| ( | ||
| gate_stmts.T, | ||
| gate_stmts.Rx, | ||
| gate_stmts.Ry, | ||
| gate_stmts.Rz, | ||
| ), | ||
| ): | ||
| return rewrite_abc.RewriteResult() | ||
|
|
||
| rule = getattr(self, f"rewrite_{type(node).__name__}", self.default) | ||
|
|
||
| return rule(node) | ||
|
|
||
| def default(self, node: ir.Statement) -> rewrite_abc.RewriteResult: | ||
| return rewrite_abc.RewriteResult() | ||
|
|
||
| def rewrite_T(self, node: gate_stmts.T) -> rewrite_abc.RewriteResult: | ||
| if node.adjoint: | ||
| lam_value = -1.0 / 8.0 | ||
| else: | ||
| lam_value = 1.0 / 8.0 | ||
|
|
||
| (theta_stmt := py.Constant(0.0)).insert_before(node) | ||
| (phi_stmt := py.Constant(0.0)).insert_before(node) | ||
| (lam_stmt := py.Constant(lam_value)).insert_before(node) | ||
|
|
||
| node.replace_by( | ||
| gate_stmts.U3( | ||
| qubits=node.qubits, | ||
| theta=theta_stmt.result, | ||
| phi=phi_stmt.result, | ||
| lam=lam_stmt.result, | ||
| ) | ||
| ) | ||
|
|
||
| return rewrite_abc.RewriteResult(has_done_something=True) | ||
|
|
||
| def rewrite_Rx(self, node: gate_stmts.Rx) -> rewrite_abc.RewriteResult: | ||
| (phi_stmt := py.Constant(-0.25)).insert_before(node) | ||
| (lam_stmt := py.Constant(0.25)).insert_before(node) | ||
|
|
||
| node.replace_by( | ||
| gate_stmts.U3( | ||
| qubits=node.qubits, | ||
| theta=node.angle, | ||
| phi=phi_stmt.result, | ||
| lam=lam_stmt.result, | ||
| ) | ||
| ) | ||
|
|
||
| return rewrite_abc.RewriteResult(has_done_something=True) | ||
|
|
||
| def rewrite_Ry(self, node: gate_stmts.Ry) -> rewrite_abc.RewriteResult: | ||
| (phi_stmt := py.Constant(0.0)).insert_before(node) | ||
| (lam_stmt := py.Constant(0.0)).insert_before(node) | ||
|
|
||
| node.replace_by( | ||
| gate_stmts.U3( | ||
| qubits=node.qubits, | ||
| theta=node.angle, | ||
| phi=phi_stmt.result, | ||
| lam=lam_stmt.result, | ||
| ) | ||
| ) | ||
|
|
||
| return rewrite_abc.RewriteResult(has_done_something=True) | ||
|
|
||
| def rewrite_Rz(self, node: gate_stmts.Rz) -> rewrite_abc.RewriteResult: | ||
| (theta_stmt := py.Constant(0.0)).insert_before(node) | ||
| (phi_stmt := py.Constant(0.0)).insert_before(node) | ||
|
|
||
| node.replace_by( | ||
| gate_stmts.U3( | ||
| qubits=node.qubits, | ||
| theta=theta_stmt.result, | ||
| phi=phi_stmt.result, | ||
| lam=node.angle, | ||
| ) | ||
| ) | ||
|
|
||
| return rewrite_abc.RewriteResult(has_done_something=True) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,143 @@ | ||
| from kirin import ir, rewrite | ||
| from kirin.dialects import py | ||
|
|
||
| from bloqade.squin.gate import stmts as gate_stmts | ||
| from bloqade.test_utils import assert_nodes | ||
| from bloqade.squin.rewrite.non_clifford_to_U3 import RewriteNonCliffordToU3 | ||
|
|
||
|
|
||
| def test_rewrite_T(): | ||
| test_qubits = ir.TestValue() | ||
| test_block = ir.Block([gate_stmts.T(qubits=test_qubits, adjoint=False)]) | ||
|
|
||
| expected_block = ir.Block( | ||
| [ | ||
| theta := py.Constant(0.0), | ||
| phi := py.Constant(0.0), | ||
| lam := py.Constant(1.0 / 8.0), | ||
| gate_stmts.U3( | ||
| qubits=test_qubits, | ||
| theta=theta.result, | ||
| phi=phi.result, | ||
| lam=lam.result, | ||
| ), | ||
| ] | ||
| ) | ||
|
|
||
| rule = rewrite.Walk(RewriteNonCliffordToU3()) | ||
| rule.rewrite(test_block) | ||
|
|
||
| assert_nodes(test_block, expected_block) | ||
|
|
||
|
|
||
| def test_rewrite_Tadj(): | ||
| test_qubits = ir.TestValue() | ||
| test_block = ir.Block([gate_stmts.T(qubits=test_qubits, adjoint=True)]) | ||
|
|
||
| expected_block = ir.Block( | ||
| [ | ||
| theta := py.Constant(0.0), | ||
| phi := py.Constant(0.0), | ||
| lam := py.Constant(-1.0 / 8.0), | ||
| gate_stmts.U3( | ||
| qubits=test_qubits, | ||
| theta=theta.result, | ||
| phi=phi.result, | ||
| lam=lam.result, | ||
| ), | ||
| ] | ||
| ) | ||
|
|
||
| rule = rewrite.Walk(RewriteNonCliffordToU3()) | ||
| rule.rewrite(test_block) | ||
|
|
||
| assert_nodes(test_block, expected_block) | ||
|
|
||
|
|
||
| def test_rewrite_Ry(): | ||
| test_qubits = ir.TestValue() | ||
| angle = ir.TestValue() | ||
| test_block = ir.Block([gate_stmts.Ry(qubits=test_qubits, angle=angle)]) | ||
|
|
||
| expected_block = ir.Block( | ||
| [ | ||
| phi := py.Constant(0.0), | ||
| lam := py.Constant(0.0), | ||
| gate_stmts.U3( | ||
| qubits=test_qubits, | ||
| theta=angle, | ||
| phi=phi.result, | ||
| lam=lam.result, | ||
| ), | ||
| ] | ||
| ) | ||
|
|
||
| rule = rewrite.Walk(RewriteNonCliffordToU3()) | ||
| rule.rewrite(test_block) | ||
|
|
||
| assert_nodes(test_block, expected_block) | ||
|
|
||
|
|
||
| def test_rewrite_Rz(): | ||
| test_qubits = ir.TestValue() | ||
| angle = ir.TestValue() | ||
| test_block = ir.Block([gate_stmts.Rz(qubits=test_qubits, angle=angle)]) | ||
|
|
||
| expected_block = ir.Block( | ||
| [ | ||
| theta := py.Constant(0.0), | ||
| phi := py.Constant(0.0), | ||
| gate_stmts.U3( | ||
| qubits=test_qubits, | ||
| theta=theta.result, | ||
| phi=phi.result, | ||
| lam=angle, | ||
| ), | ||
| ] | ||
| ) | ||
|
|
||
| rule = rewrite.Walk(RewriteNonCliffordToU3()) | ||
| rule.rewrite(test_block) | ||
|
|
||
| assert_nodes(test_block, expected_block) | ||
|
|
||
|
|
||
| def test_rewrite_Rx(): | ||
| test_qubits = ir.TestValue() | ||
| angle = ir.TestValue() | ||
| test_block = ir.Block([gate_stmts.Rx(qubits=test_qubits, angle=angle)]) | ||
|
|
||
| expected_block = ir.Block( | ||
| [ | ||
| phi := py.Constant(-0.25), | ||
| lam := py.Constant(0.25), | ||
| gate_stmts.U3( | ||
| qubits=test_qubits, | ||
| theta=angle, | ||
| phi=phi.result, | ||
| lam=lam.result, | ||
| ), | ||
| ] | ||
| ) | ||
|
|
||
| rule = rewrite.Walk(RewriteNonCliffordToU3()) | ||
| rule.rewrite(test_block) | ||
|
|
||
| assert_nodes(test_block, expected_block) | ||
|
|
||
|
|
||
| def test_no_op(): | ||
| test_qubits = ir.TestValue() | ||
| angle = ir.TestValue() | ||
| test_block = ir.Block( | ||
| [gate_stmts.U3(qubits=test_qubits, theta=angle, phi=angle, lam=angle)] | ||
| ) | ||
|
|
||
| expected_block = ir.Block( | ||
| [gate_stmts.U3(qubits=test_qubits, theta=angle, phi=angle, lam=angle)] | ||
| ) | ||
|
|
||
| rule = rewrite.Walk(RewriteNonCliffordToU3()) | ||
| rule.rewrite(test_block) | ||
|
|
||
| assert_nodes(test_block, expected_block) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.