Skip to content

Commit cf2048d

Browse files
committed
expanding tests to test all rotation-like gates
1 parent 26de281 commit cf2048d

File tree

3 files changed

+42
-6
lines changed

3 files changed

+42
-6
lines changed

src/bloqade/native/stdlib/broadcast.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,15 @@ def shift(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
202202
rz(angle / 2.0, qubits)
203203

204204

205+
@kernel
206+
def _u3_turns(
207+
theta: float, phi: float, lam: float, qubits: ilist.IList[qubit.Qubit, Any]
208+
):
209+
_rz_turns(lam, qubits)
210+
_ry_turns(theta, qubits)
211+
_rz_turns(phi, qubits)
212+
213+
205214
@kernel
206215
def u3(theta: float, phi: float, lam: float, qubits: ilist.IList[qubit.Qubit, Any]):
207216
"""Apply the U3 gate on a group of qubits.

src/bloqade/native/upstream/squin2native.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class GateRule(RewriteRule):
2626
stmts.CX: (broadcast.cx,),
2727
stmts.CY: (broadcast.cy,),
2828
stmts.CZ: (broadcast.cz,),
29-
stmts.U3: (broadcast.u3,),
29+
stmts.U3: (broadcast._u3_turns,),
3030
}
3131

3232
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:

test/native/upstream/test_squin2native.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,24 +51,51 @@ def main():
5151
assert np.allclose(old_sv, new_sv)
5252

5353

54-
def test_pipeline():
54+
@pytest.mark.parametrize(
55+
("squin_gate", "native_gate"),
56+
[
57+
(squin.rz, native.rz),
58+
(squin.rx, native.rx),
59+
(squin.ry, native.ry),
60+
],
61+
)
62+
def test_pipeline(squin_gate, native_gate):
5563

5664
@squin.kernel
5765
def ghz(angle: float):
5866
qubits = squin.qalloc(1)
59-
squin.rz(angle, qubits[0])
67+
squin_gate(angle, qubits[0])
6068

6169
@native.kernel
6270
def ghz_native(angle: float):
6371
qubits = squin.qalloc(1)
64-
native.rz(angle, qubits[0])
72+
native_gate(angle, qubits[0])
73+
74+
ghz_native_rewrite = SquinToNative().emit(ghz)
75+
AggressiveUnroll(ghz.dialects).fixpoint(ghz_native_rewrite)
76+
77+
AggressiveUnroll(ghz_native.dialects).fixpoint(ghz_native)
78+
test_utils.assert_nodes(
79+
ghz_native_rewrite.callable_region, ghz_native.callable_region
80+
)
81+
82+
83+
def test_pipeline_u3():
84+
85+
@squin.kernel
86+
def ghz(theta: float, phi: float, lam: float):
87+
qubits = squin.qalloc(1)
88+
squin.u3(theta, phi, lam, qubits[0])
89+
90+
@native.kernel
91+
def ghz_native(theta: float, phi: float, lam: float):
92+
qubits = squin.qalloc(1)
93+
native.u3(theta, phi, lam, qubits[0])
6594

6695
ghz_native_rewrite = SquinToNative().emit(ghz)
6796
AggressiveUnroll(ghz.dialects).fixpoint(ghz_native_rewrite)
68-
ghz_native_rewrite.print()
6997

7098
AggressiveUnroll(ghz_native.dialects).fixpoint(ghz_native)
71-
ghz_native.print()
7299
test_utils.assert_nodes(
73100
ghz_native_rewrite.callable_region, ghz_native.callable_region
74101
)

0 commit comments

Comments
 (0)