Skip to content

Commit 5d17f26

Browse files
authored
Merge branch 'main' into tcochran-measure_and_reset
2 parents 03cd75a + 8fa9d04 commit 5d17f26

File tree

4 files changed

+91
-12
lines changed

4 files changed

+91
-12
lines changed

src/bloqade/native/stdlib/broadcast.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,30 @@ def _radian_to_turn(angle: float) -> float:
2121
return angle / (2 * math.pi)
2222

2323

24+
@kernel
25+
def _rx_turns(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
26+
native.r(0.0, angle, qubits)
27+
28+
29+
@kernel
30+
def _ry_turns(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
31+
native.r(0.25, angle, qubits)
32+
33+
34+
@kernel
35+
def _rz_turns(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
36+
native.rz(angle, qubits)
37+
38+
39+
@kernel
40+
def _u3_turns(
41+
theta: float, phi: float, lam: float, qubits: ilist.IList[qubit.Qubit, Any]
42+
):
43+
_rz_turns(lam, qubits)
44+
_ry_turns(theta, qubits)
45+
_rz_turns(phi, qubits)
46+
47+
2448
@kernel
2549
def rx(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
2650
"""Apply an RX rotation gate on a group of qubits.
@@ -29,7 +53,7 @@ def rx(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
2953
angle (float): Rotation angle in radians.
3054
qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
3155
"""
32-
native.r(0.0, _radian_to_turn(angle), qubits)
56+
_rx_turns(_radian_to_turn(angle), qubits)
3357

3458

3559
@kernel
@@ -70,7 +94,7 @@ def ry(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
7094
angle (float): Rotation angle in radians.
7195
qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
7296
"""
73-
native.r(0.25, _radian_to_turn(angle), qubits)
97+
_ry_turns(_radian_to_turn(angle), qubits)
7498

7599

76100
@kernel
@@ -111,7 +135,7 @@ def rz(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
111135
angle (float): Rotation angle in radians.
112136
qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
113137
"""
114-
native.rz(_radian_to_turn(angle), qubits)
138+
_rz_turns(_radian_to_turn(angle), qubits)
115139

116140

117141
@kernel
@@ -201,9 +225,9 @@ def u3(theta: float, phi: float, lam: float, qubits: ilist.IList[qubit.Qubit, An
201225
lam (float): Z rotations in decomposition (radians).
202226
qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
203227
"""
204-
rz(lam, qubits)
205-
ry(theta, qubits)
206-
rz(phi, qubits)
228+
_u3_turns(
229+
_radian_to_turn(theta), _radian_to_turn(phi), _radian_to_turn(lam), qubits
230+
)
207231

208232

209233
N = TypeVar("N")

src/bloqade/native/upstream/squin2native.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ class GateRule(RewriteRule):
2020
stmts.T: (broadcast.t, broadcast.t_adj),
2121
stmts.SqrtX: (broadcast.sqrt_x, broadcast.sqrt_x_adj),
2222
stmts.SqrtY: (broadcast.sqrt_y, broadcast.sqrt_y_adj),
23-
stmts.Rx: (broadcast.rx,),
24-
stmts.Ry: (broadcast.ry,),
25-
stmts.Rz: (broadcast.rz,),
23+
stmts.Rx: (broadcast._rx_turns,),
24+
stmts.Ry: (broadcast._ry_turns,),
25+
stmts.Rz: (broadcast._rz_turns,),
2626
stmts.CX: (broadcast.cx,),
2727
stmts.CY: (broadcast.cy,),
2828
stmts.CZ: (broadcast.cz,),
29-
stmts.U3: (broadcast.u3,),
29+
stmts.U3: (broadcast._u3_turns,),
3030
}
3131

3232
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:

src/bloqade/rewrite/passes/callgraph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult:
9898
mt_map = {}
9999

100100
cg = CallGraph(mt)
101-
102101
all_methods = set(cg.edges.keys())
102+
all_methods.add(mt)
103103
for original_mt in all_methods:
104104
if original_mt is mt:
105105
new_mt = original_mt

test/native/upstream/test_squin2native.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
import pytest
33
from kirin.analysis import callgraph
44

5-
from bloqade import squin
5+
from bloqade import squin, native, test_utils
66
from bloqade.squin import gate
77
from bloqade.pyqrack import StackMemorySimulator
8+
from bloqade.rewrite.passes import AggressiveUnroll
89
from bloqade.native.dialects import gate as native_gate
910
from bloqade.native.upstream import GateRule, SquinToNative
1011

@@ -48,3 +49,57 @@ def main():
4849
new_sv /= new_sv[imax := np.abs(new_sv).argmax()] / np.abs(new_sv[imax])
4950

5051
assert np.allclose(old_sv, new_sv)
52+
53+
54+
@pytest.mark.parametrize(
55+
("squin_gate", "native_gate"),
56+
[
57+
(squin.rz, native.rz),
58+
(squin.rx, native.rx),
59+
(squin.ry, native.ry),
60+
],
61+
)
62+
def test_pipeline(squin_gate, native_gate):
63+
64+
@squin.kernel
65+
def ghz(angle: float):
66+
qubits = squin.qalloc(1)
67+
squin_gate(angle, qubits[0])
68+
69+
@native.kernel
70+
def ghz_native(angle: float):
71+
qubits = squin.qalloc(1)
72+
native_gate(angle, qubits[0])
73+
74+
ghz_native_rewrite = SquinToNative().emit(ghz)
75+
AggressiveUnroll(ghz.dialects).fixpoint(ghz_native_rewrite)
76+
77+
AggressiveUnroll(ghz_native.dialects).fixpoint(ghz_native)
78+
test_utils.assert_nodes(
79+
ghz_native_rewrite.callable_region.blocks[0],
80+
ghz_native.callable_region.blocks[0],
81+
)
82+
83+
84+
def test_pipeline_u3():
85+
86+
@squin.kernel
87+
def ghz(theta: float, phi: float, lam: float):
88+
qubits = squin.qalloc(1)
89+
squin.u3(theta, phi, lam, qubits[0])
90+
91+
@native.kernel
92+
def ghz_native(theta: float, phi: float, lam: float):
93+
qubits = squin.qalloc(1)
94+
native.u3(theta, phi, lam, qubits[0])
95+
96+
# unroll first to check that gete rewrites happen in ghz body
97+
AggressiveUnroll(ghz.dialects).fixpoint(ghz)
98+
ghz_native_rewrite = SquinToNative().emit(ghz)
99+
AggressiveUnroll(ghz.dialects).fixpoint(ghz_native_rewrite)
100+
101+
AggressiveUnroll(ghz_native.dialects).fixpoint(ghz_native)
102+
test_utils.assert_nodes(
103+
ghz_native_rewrite.callable_region.blocks[0],
104+
ghz_native.callable_region.blocks[0],
105+
)

0 commit comments

Comments
 (0)