Skip to content

Commit 2468f16

Browse files
committed
adding rewrite
1 parent 0958f24 commit 2468f16

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from kirin import ir
2+
from kirin.rewrite import abc as rewrite_abc
3+
from kirin.dialects import py
4+
5+
from bloqade.squin.gate import stmts as gate_stmts
6+
7+
8+
class RewriteNonCliffordToU3(rewrite_abc.RewriteRule):
9+
"""Rewrite non-Clifford gates to U3 gates.
10+
11+
This rewrite rule transforms specific non-Clifford single-qubit gates
12+
into equivalent U3 gate representations. The following transformations are applied:
13+
- T gate (with adjoint attribute) to U3 gate with parameters (0, 0, ±π/4)
14+
- Rx gate to U3 gate with parameters (angle, -π/2, π/2)
15+
- Ry gate to U3 gate with parameters (angle, 0, 0)
16+
- Rz gate is U3 gate with parameters (0, 0, angle)
17+
18+
"""
19+
20+
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
21+
rule = getattr(self, f"rewrite_{type(node).__name__}", self.default)
22+
23+
return rule(node)
24+
25+
def default(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
26+
return rewrite_abc.RewriteResult()
27+
28+
def rewrite_T(self, node: gate_stmts.T) -> rewrite_abc.RewriteResult:
29+
if node.adjoint:
30+
lam_value = -1.0 / 8.0
31+
else:
32+
lam_value = 1.0 / 8.0
33+
34+
(theta_stmt := py.Constant(0.0)).insert_before(node)
35+
(phi_stmt := py.Constant(0.0)).insert_before(node)
36+
(lam_stmt := py.Constant(lam_value)).insert_before(node)
37+
38+
node.replace_by(
39+
gate_stmts.U3(
40+
qubits=node.qubits,
41+
theta=theta_stmt.result,
42+
phi=phi_stmt.result,
43+
lam=lam_stmt.result,
44+
)
45+
)
46+
47+
return rewrite_abc.RewriteResult(has_done_something=True)
48+
49+
def rewrite_Rx(self, node: gate_stmts.Rx) -> rewrite_abc.RewriteResult:
50+
(phi_stmt := py.Constant(-0.25)).insert_before(node)
51+
(lam_stmt := py.Constant(0.25)).insert_before(node)
52+
53+
node.replace_by(
54+
gate_stmts.U3(
55+
qubits=node.qubits,
56+
theta=node.angle,
57+
phi=phi_stmt.result,
58+
lam=lam_stmt.result,
59+
)
60+
)
61+
62+
return rewrite_abc.RewriteResult(has_done_something=True)
63+
64+
def rewrite_Ry(self, node: gate_stmts.Ry) -> rewrite_abc.RewriteResult:
65+
(phi_stmt := py.Constant(0.0)).insert_before(node)
66+
(lam_stmt := py.Constant(0.0)).insert_before(node)
67+
68+
node.replace_by(
69+
gate_stmts.U3(
70+
qubits=node.qubits,
71+
theta=node.angle,
72+
phi=phi_stmt.result,
73+
lam=lam_stmt.result,
74+
)
75+
)
76+
77+
return rewrite_abc.RewriteResult(has_done_something=True)

0 commit comments

Comments
 (0)