diff --git a/src/bloqade/native/stdlib/broadcast.py b/src/bloqade/native/stdlib/broadcast.py index e8942a9a..f890edc5 100644 --- a/src/bloqade/native/stdlib/broadcast.py +++ b/src/bloqade/native/stdlib/broadcast.py @@ -21,6 +21,30 @@ def _radian_to_turn(angle: float) -> float: return angle / (2 * math.pi) +@kernel +def _rx_turns(angle: float, qubits: ilist.IList[qubit.Qubit, Any]): + native.r(0.0, angle, qubits) + + +@kernel +def _ry_turns(angle: float, qubits: ilist.IList[qubit.Qubit, Any]): + native.r(0.25, angle, qubits) + + +@kernel +def _rz_turns(angle: float, qubits: ilist.IList[qubit.Qubit, Any]): + native.rz(angle, qubits) + + +@kernel +def _u3_turns( + theta: float, phi: float, lam: float, qubits: ilist.IList[qubit.Qubit, Any] +): + _rz_turns(lam, qubits) + _ry_turns(theta, qubits) + _rz_turns(phi, qubits) + + @kernel def rx(angle: float, qubits: ilist.IList[qubit.Qubit, Any]): """Apply an RX rotation gate on a group of qubits. @@ -29,7 +53,7 @@ def rx(angle: float, qubits: ilist.IList[qubit.Qubit, Any]): angle (float): Rotation angle in radians. qubits (ilist.IList[qubit.Qubit, Any]): Target qubits. """ - native.r(0.0, _radian_to_turn(angle), qubits) + _rx_turns(_radian_to_turn(angle), qubits) @kernel @@ -70,7 +94,7 @@ def ry(angle: float, qubits: ilist.IList[qubit.Qubit, Any]): angle (float): Rotation angle in radians. qubits (ilist.IList[qubit.Qubit, Any]): Target qubits. """ - native.r(0.25, _radian_to_turn(angle), qubits) + _ry_turns(_radian_to_turn(angle), qubits) @kernel @@ -111,7 +135,7 @@ def rz(angle: float, qubits: ilist.IList[qubit.Qubit, Any]): angle (float): Rotation angle in radians. qubits (ilist.IList[qubit.Qubit, Any]): Target qubits. """ - native.rz(_radian_to_turn(angle), qubits) + _rz_turns(_radian_to_turn(angle), qubits) @kernel @@ -201,9 +225,9 @@ def u3(theta: float, phi: float, lam: float, qubits: ilist.IList[qubit.Qubit, An lam (float): Z rotations in decomposition (radians). qubits (ilist.IList[qubit.Qubit, Any]): Target qubits. """ - rz(lam, qubits) - ry(theta, qubits) - rz(phi, qubits) + _u3_turns( + _radian_to_turn(theta), _radian_to_turn(phi), _radian_to_turn(lam), qubits + ) N = TypeVar("N") diff --git a/src/bloqade/native/upstream/squin2native.py b/src/bloqade/native/upstream/squin2native.py index 998d34b6..74d7d497 100644 --- a/src/bloqade/native/upstream/squin2native.py +++ b/src/bloqade/native/upstream/squin2native.py @@ -20,13 +20,13 @@ class GateRule(RewriteRule): stmts.T: (broadcast.t, broadcast.t_adj), stmts.SqrtX: (broadcast.sqrt_x, broadcast.sqrt_x_adj), stmts.SqrtY: (broadcast.sqrt_y, broadcast.sqrt_y_adj), - stmts.Rx: (broadcast.rx,), - stmts.Ry: (broadcast.ry,), - stmts.Rz: (broadcast.rz,), + stmts.Rx: (broadcast._rx_turns,), + stmts.Ry: (broadcast._ry_turns,), + stmts.Rz: (broadcast._rz_turns,), stmts.CX: (broadcast.cx,), stmts.CY: (broadcast.cy,), stmts.CZ: (broadcast.cz,), - stmts.U3: (broadcast.u3,), + stmts.U3: (broadcast._u3_turns,), } def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: diff --git a/src/bloqade/rewrite/passes/callgraph.py b/src/bloqade/rewrite/passes/callgraph.py index 0b5e64f0..5a46e579 100644 --- a/src/bloqade/rewrite/passes/callgraph.py +++ b/src/bloqade/rewrite/passes/callgraph.py @@ -98,8 +98,8 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult: mt_map = {} cg = CallGraph(mt) - all_methods = set(cg.edges.keys()) + all_methods.add(mt) for original_mt in all_methods: if original_mt is mt: new_mt = original_mt diff --git a/test/native/upstream/test_squin2native.py b/test/native/upstream/test_squin2native.py index 7d250e67..a9d26d39 100644 --- a/test/native/upstream/test_squin2native.py +++ b/test/native/upstream/test_squin2native.py @@ -2,9 +2,10 @@ import pytest from kirin.analysis import callgraph -from bloqade import squin +from bloqade import squin, native, test_utils from bloqade.squin import gate from bloqade.pyqrack import StackMemorySimulator +from bloqade.rewrite.passes import AggressiveUnroll from bloqade.native.dialects import gate as native_gate from bloqade.native.upstream import GateRule, SquinToNative @@ -48,3 +49,57 @@ def main(): new_sv /= new_sv[imax := np.abs(new_sv).argmax()] / np.abs(new_sv[imax]) assert np.allclose(old_sv, new_sv) + + +@pytest.mark.parametrize( + ("squin_gate", "native_gate"), + [ + (squin.rz, native.rz), + (squin.rx, native.rx), + (squin.ry, native.ry), + ], +) +def test_pipeline(squin_gate, native_gate): + + @squin.kernel + def ghz(angle: float): + qubits = squin.qalloc(1) + squin_gate(angle, qubits[0]) + + @native.kernel + def ghz_native(angle: float): + qubits = squin.qalloc(1) + native_gate(angle, qubits[0]) + + ghz_native_rewrite = SquinToNative().emit(ghz) + AggressiveUnroll(ghz.dialects).fixpoint(ghz_native_rewrite) + + AggressiveUnroll(ghz_native.dialects).fixpoint(ghz_native) + test_utils.assert_nodes( + ghz_native_rewrite.callable_region.blocks[0], + ghz_native.callable_region.blocks[0], + ) + + +def test_pipeline_u3(): + + @squin.kernel + def ghz(theta: float, phi: float, lam: float): + qubits = squin.qalloc(1) + squin.u3(theta, phi, lam, qubits[0]) + + @native.kernel + def ghz_native(theta: float, phi: float, lam: float): + qubits = squin.qalloc(1) + native.u3(theta, phi, lam, qubits[0]) + + # unroll first to check that gete rewrites happen in ghz body + AggressiveUnroll(ghz.dialects).fixpoint(ghz) + ghz_native_rewrite = SquinToNative().emit(ghz) + AggressiveUnroll(ghz.dialects).fixpoint(ghz_native_rewrite) + + AggressiveUnroll(ghz_native.dialects).fixpoint(ghz_native) + test_utils.assert_nodes( + ghz_native_rewrite.callable_region.blocks[0], + ghz_native.callable_region.blocks[0], + )