Skip to content

Commit ae462fc

Browse files
committed
fixing bugs + adding tests
1 parent 2468f16 commit ae462fc

File tree

2 files changed

+171
-0
lines changed

2 files changed

+171
-0
lines changed

src/bloqade/squin/rewrite/non_clifford_to_U3.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,22 @@ class RewriteNonCliffordToU3(rewrite_abc.RewriteRule):
1515
- Ry gate to U3 gate with parameters (angle, 0, 0)
1616
- Rz gate is U3 gate with parameters (0, 0, angle)
1717
18+
This rewrite should be paired with `U3ToClifford` to canonicalize the circuit.
19+
1820
"""
1921

2022
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
23+
if not isinstance(
24+
node,
25+
(
26+
gate_stmts.T,
27+
gate_stmts.Rx,
28+
gate_stmts.Ry,
29+
gate_stmts.Rz,
30+
),
31+
):
32+
return rewrite_abc.RewriteResult()
33+
2134
rule = getattr(self, f"rewrite_{type(node).__name__}", self.default)
2235

2336
return rule(node)
@@ -75,3 +88,18 @@ def rewrite_Ry(self, node: gate_stmts.Ry) -> rewrite_abc.RewriteResult:
7588
)
7689

7790
return rewrite_abc.RewriteResult(has_done_something=True)
91+
92+
def rewrite_Rz(self, node: gate_stmts.Rz) -> rewrite_abc.RewriteResult:
93+
(theta_stmt := py.Constant(0.0)).insert_before(node)
94+
(phi_stmt := py.Constant(0.0)).insert_before(node)
95+
96+
node.replace_by(
97+
gate_stmts.U3(
98+
qubits=node.qubits,
99+
theta=theta_stmt.result,
100+
phi=phi_stmt.result,
101+
lam=node.angle,
102+
)
103+
)
104+
105+
return rewrite_abc.RewriteResult(has_done_something=True)
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
from kirin import ir, rewrite
2+
from kirin.dialects import py
3+
4+
from bloqade.squin.gate import stmts as gate_stmts
5+
from bloqade.test_utils import assert_nodes
6+
from bloqade.squin.rewrite.non_clifford_to_U3 import RewriteNonCliffordToU3
7+
8+
9+
def test_rewrite_T():
10+
test_qubits = ir.TestValue()
11+
test_block = ir.Block([gate_stmts.T(qubits=test_qubits, adjoint=False)])
12+
13+
expected_block = ir.Block(
14+
[
15+
theta := py.Constant(0.0),
16+
phi := py.Constant(0.0),
17+
lam := py.Constant(1.0 / 8.0),
18+
gate_stmts.U3(
19+
qubits=test_qubits,
20+
theta=theta.result,
21+
phi=phi.result,
22+
lam=lam.result,
23+
),
24+
]
25+
)
26+
27+
rule = rewrite.Walk(RewriteNonCliffordToU3())
28+
rule.rewrite(test_block)
29+
30+
assert_nodes(test_block, expected_block)
31+
32+
33+
def test_rewrite_Tadj():
34+
test_qubits = ir.TestValue()
35+
test_block = ir.Block([gate_stmts.T(qubits=test_qubits, adjoint=True)])
36+
37+
expected_block = ir.Block(
38+
[
39+
theta := py.Constant(0.0),
40+
phi := py.Constant(0.0),
41+
lam := py.Constant(-1.0 / 8.0),
42+
gate_stmts.U3(
43+
qubits=test_qubits,
44+
theta=theta.result,
45+
phi=phi.result,
46+
lam=lam.result,
47+
),
48+
]
49+
)
50+
51+
rule = rewrite.Walk(RewriteNonCliffordToU3())
52+
rule.rewrite(test_block)
53+
54+
assert_nodes(test_block, expected_block)
55+
56+
57+
def test_rewrite_Ry():
58+
test_qubits = ir.TestValue()
59+
angle = ir.TestValue()
60+
test_block = ir.Block([gate_stmts.Ry(qubits=test_qubits, angle=angle)])
61+
62+
expected_block = ir.Block(
63+
[
64+
phi := py.Constant(0.0),
65+
lam := py.Constant(0.0),
66+
gate_stmts.U3(
67+
qubits=test_qubits,
68+
theta=angle,
69+
phi=phi.result,
70+
lam=lam.result,
71+
),
72+
]
73+
)
74+
75+
rule = rewrite.Walk(RewriteNonCliffordToU3())
76+
rule.rewrite(test_block)
77+
78+
assert_nodes(test_block, expected_block)
79+
80+
81+
def test_rewrite_Rz():
82+
test_qubits = ir.TestValue()
83+
angle = ir.TestValue()
84+
test_block = ir.Block([gate_stmts.Rz(qubits=test_qubits, angle=angle)])
85+
86+
expected_block = ir.Block(
87+
[
88+
theta := py.Constant(0.0),
89+
phi := py.Constant(0.0),
90+
gate_stmts.U3(
91+
qubits=test_qubits,
92+
theta=theta.result,
93+
phi=phi.result,
94+
lam=angle,
95+
),
96+
]
97+
)
98+
99+
rule = rewrite.Walk(RewriteNonCliffordToU3())
100+
rule.rewrite(test_block)
101+
102+
assert_nodes(test_block, expected_block)
103+
104+
105+
def test_rewrite_Rx():
106+
test_qubits = ir.TestValue()
107+
angle = ir.TestValue()
108+
test_block = ir.Block([gate_stmts.Rx(qubits=test_qubits, angle=angle)])
109+
110+
expected_block = ir.Block(
111+
[
112+
phi := py.Constant(-0.25),
113+
lam := py.Constant(0.25),
114+
gate_stmts.U3(
115+
qubits=test_qubits,
116+
theta=angle,
117+
phi=phi.result,
118+
lam=lam.result,
119+
),
120+
]
121+
)
122+
123+
rule = rewrite.Walk(RewriteNonCliffordToU3())
124+
rule.rewrite(test_block)
125+
126+
assert_nodes(test_block, expected_block)
127+
128+
129+
def test_no_op():
130+
test_qubits = ir.TestValue()
131+
angle = ir.TestValue()
132+
test_block = ir.Block(
133+
[gate_stmts.U3(qubits=test_qubits, theta=angle, phi=angle, lam=angle)]
134+
)
135+
136+
expected_block = ir.Block(
137+
[gate_stmts.U3(qubits=test_qubits, theta=angle, phi=angle, lam=angle)]
138+
)
139+
140+
rule = rewrite.Walk(RewriteNonCliffordToU3())
141+
rule.rewrite(test_block)
142+
143+
assert_nodes(test_block, expected_block)

0 commit comments

Comments
 (0)