Skip to content

Commit dc6a6f3

Browse files
committed
Refactoring stdlib to fix rewrite
1 parent 5b853b7 commit dc6a6f3

File tree

3 files changed

+44
-6
lines changed

3 files changed

+44
-6
lines changed

src/bloqade/native/stdlib/broadcast.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,21 @@ def _radian_to_turn(angle: float) -> float:
2121
return angle / (2 * math.pi)
2222

2323

24+
@kernel
25+
def _rx_turns(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
26+
native.r(0.0, angle, qubits)
27+
28+
29+
@kernel
30+
def _ry_turns(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
31+
native.r(0.25, angle, qubits)
32+
33+
34+
@kernel
35+
def _rz_turns(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
36+
native.r(0.5, angle, qubits)
37+
38+
2439
@kernel
2540
def rx(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
2641
"""Apply an RX rotation gate on a group of qubits.
@@ -29,7 +44,7 @@ def rx(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
2944
angle (float): Rotation angle in radians.
3045
qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
3146
"""
32-
native.r(0.0, _radian_to_turn(angle), qubits)
47+
_rx_turns(_radian_to_turn(angle), qubits)
3348

3449

3550
@kernel
@@ -70,7 +85,7 @@ def ry(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
7085
angle (float): Rotation angle in radians.
7186
qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
7287
"""
73-
native.r(0.25, _radian_to_turn(angle), qubits)
88+
_ry_turns(_radian_to_turn(angle), qubits)
7489

7590

7691
@kernel

src/bloqade/native/upstream/squin2native.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ class GateRule(RewriteRule):
2020
stmts.T: (broadcast.t, broadcast.t_adj),
2121
stmts.SqrtX: (broadcast.sqrt_x, broadcast.sqrt_x_adj),
2222
stmts.SqrtY: (broadcast.sqrt_y, broadcast.sqrt_y_adj),
23-
stmts.Rx: (broadcast.rx,),
24-
stmts.Ry: (broadcast.ry,),
25-
stmts.Rz: (broadcast.rz,),
23+
stmts.Rx: (broadcast._rx_turns,),
24+
stmts.Ry: (broadcast._ry_turns,),
25+
stmts.Rz: (broadcast._rz_turns,),
2626
stmts.CX: (broadcast.cx,),
2727
stmts.CY: (broadcast.cy,),
2828
stmts.CZ: (broadcast.cz,),

test/native/upstream/test_squin2native.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
import pytest
33
from kirin.analysis import callgraph
44

5-
from bloqade import squin
5+
from bloqade import squin, native
66
from bloqade.squin import gate
77
from bloqade.pyqrack import StackMemorySimulator
8+
from bloqade.test_utils import assert_methods
9+
from bloqade.rewrite.passes import AggressiveUnroll
810
from bloqade.native.dialects import gate as native_gate
911
from bloqade.native.upstream import GateRule, SquinToNative
1012

@@ -48,3 +50,24 @@ def main():
4850
new_sv /= new_sv[imax := np.abs(new_sv).argmax()] / np.abs(new_sv[imax])
4951

5052
assert np.allclose(old_sv, new_sv)
53+
54+
55+
def test_pipeline():
56+
57+
@squin.kernel
58+
def ghz(angle: float):
59+
qubits = squin.qalloc(1)
60+
squin.rz(angle, qubits[0])
61+
62+
@native.kernel
63+
def ghz_native(angle: float):
64+
qubits = squin.qalloc(1)
65+
native.rz(angle, qubits[0])
66+
67+
ghz_native_rewrite = SquinToNative().emit(ghz)
68+
AggressiveUnroll(ghz.dialects).fixpoint(ghz_native_rewrite)
69+
ghz_native_rewrite.print()
70+
71+
AggressiveUnroll(ghz_native.dialects).fixpoint(ghz_native)
72+
ghz_native.print()
73+
assert_methods(ghz_native_rewrite, ghz_native)

0 commit comments

Comments
 (0)