Skip to content

Commit 12fe6c2

Browse files
kaihsinehua7365johnzl-777
authored andcommitted
rewrite rule for U3 to clifford (#335)
Co-authored-by: E Huang <[email protected]> Co-authored-by: John Long <[email protected]>
1 parent a19108c commit 12fe6c2

File tree

10 files changed

+677
-3
lines changed

10 files changed

+677
-3
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ repos:
1414
- id: isort
1515
name: isort (python)
1616
- repo: https://github.com/psf/black
17-
rev: 24.10.0
17+
rev: 25.1.0
1818
hooks:
1919
- id: black
2020
- repo: https://github.com/charliermarsh/ruff-pre-commit

src/bloqade/squin/op/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
shift as shift,
3131
spin_n as spin_n,
3232
spin_p as spin_p,
33+
sqrt_x as sqrt_x,
34+
sqrt_y as sqrt_y,
35+
sqrt_z as sqrt_z,
3336
adjoint as adjoint,
3437
control as control,
3538
identity as identity,

src/bloqade/squin/op/_wrapper.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,18 @@ def y() -> types.Op: ...
6969
def z() -> types.Op: ...
7070

7171

72+
@wraps(stmts.SqrtX)
73+
def sqrt_x() -> types.Op: ...
74+
75+
76+
@wraps(stmts.SqrtY)
77+
def sqrt_y() -> types.Op: ...
78+
79+
80+
@wraps(stmts.S)
81+
def sqrt_z() -> types.Op: ...
82+
83+
7284
@wraps(stmts.H)
7385
def h() -> types.Op: ...
7486

src/bloqade/squin/op/stmts.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,12 @@ class Reset(PrimitiveOp):
142142

143143

144144
@statement
145-
class PauliOp(ConstantUnitary):
145+
class CliffordOp(ConstantUnitary):
146+
pass
147+
148+
149+
@statement
150+
class PauliOp(CliffordOp):
146151
pass
147152

148153

@@ -173,6 +178,19 @@ class Z(PauliOp):
173178
pass
174179

175180

181+
@statement(dialect=dialect)
182+
class SqrtX(ConstantUnitary):
183+
pass
184+
185+
186+
@statement(dialect=dialect)
187+
class SqrtY(ConstantUnitary):
188+
pass
189+
190+
191+
# NOTE no SqrtZ since its equal to S
192+
193+
176194
@statement(dialect=dialect)
177195
class H(ConstantUnitary):
178196
pass
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# create rewrite rule name SquinMeasureToStim using kirin
2+
import math
3+
from typing import List, Tuple, Callable
4+
5+
import numpy as np
6+
from kirin import ir
7+
from kirin.dialects import py
8+
from kirin.rewrite.abc import RewriteRule, RewriteResult
9+
10+
from bloqade.squin import op, qubit
11+
12+
13+
def sdag() -> list[ir.Statement]:
14+
return [_op := op.stmts.S(), op.stmts.Adjoint(op=_op.result, is_unitary=True)]
15+
16+
17+
# (theta, phi, lam)
18+
U3_HALF_PI_ANGLE_TO_GATES: dict[
19+
tuple[int, int, int], Callable[[], Tuple[List[ir.Statement], ...]]
20+
] = {
21+
(0, 0, 0): lambda: ([op.stmts.Identity(sites=1)],),
22+
(0, 0, 1): lambda: ([op.stmts.S()],),
23+
(0, 0, 2): lambda: ([op.stmts.Z()],),
24+
(0, 0, 3): lambda: (sdag(),),
25+
(1, 0, 0): lambda: ([op.stmts.SqrtY()],),
26+
(1, 0, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()]),
27+
(1, 0, 2): lambda: ([op.stmts.H()],),
28+
(1, 0, 3): lambda: (sdag(), [op.stmts.SqrtY()]),
29+
(1, 1, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.S()]),
30+
(1, 1, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.S()]),
31+
(1, 1, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.S()]),
32+
(1, 1, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.S()]),
33+
(1, 2, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.Z()]),
34+
(1, 2, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.Z()]),
35+
(1, 2, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.Z()]),
36+
(1, 2, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.Z()]),
37+
(1, 3, 0): lambda: ([op.stmts.SqrtY()], sdag()),
38+
(1, 3, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], sdag()),
39+
(1, 3, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], sdag()),
40+
(1, 3, 3): lambda: (sdag(), [op.stmts.SqrtY()], sdag()),
41+
(2, 0, 0): lambda: ([op.stmts.Y()],),
42+
(2, 0, 1): lambda: ([op.stmts.S()], [op.stmts.Y()]),
43+
(2, 0, 2): lambda: ([op.stmts.Z()], [op.stmts.Y()]),
44+
(2, 0, 3): lambda: (sdag(), [op.stmts.Y()]),
45+
}
46+
47+
48+
def equivalent_u3_para(
49+
theta_half_pi: int, phi_half_pi: int, lam_half_pi: int
50+
) -> tuple[int, int, int]:
51+
"""
52+
1. Assume all three angles are in the range [0, 4].
53+
2. U3(theta, phi, lam) = -U3(2pi-theta, phi+pi, lam+pi).
54+
"""
55+
return ((4 - theta_half_pi) % 4, (phi_half_pi + 2) % 4, (lam_half_pi + 2) % 4)
56+
57+
58+
class SquinU3ToClifford(RewriteRule):
59+
"""
60+
Rewrite squin U3 statements to clifford when possible.
61+
"""
62+
63+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
64+
if isinstance(node, (qubit.Apply, qubit.Broadcast)):
65+
return self.rewrite_ApplyOrBroadcast_onU3(node)
66+
else:
67+
return RewriteResult()
68+
69+
def get_constant(self, node: ir.SSAValue) -> float | None:
70+
if isinstance(node.owner, py.Constant):
71+
# node.value is a PyAttr, need to get the wrapped value out
72+
return node.owner.value.unwrap()
73+
else:
74+
return None
75+
76+
def resolve_angle(self, angle: float) -> int | None:
77+
"""
78+
Normalize the angle to be in the range [0, 2π).
79+
"""
80+
# convert to 0.0~1.0, in unit of pi/2
81+
angle_half_pi = angle / math.pi * 2.0
82+
83+
mod = angle_half_pi % 1.0
84+
if not (np.isclose(mod, 0.0) or np.isclose(mod, 1.0)):
85+
return None
86+
87+
else:
88+
return round((angle / math.tau) % 1 * 4) % 4
89+
90+
def rewrite_ApplyOrBroadcast_onU3(
91+
self, node: qubit.Apply | qubit.Broadcast
92+
) -> RewriteResult:
93+
"""
94+
Rewrite Apply and Broadcast nodes to their clifford equivalent statements.
95+
"""
96+
if not isinstance(node.operator.owner, op.stmts.U3):
97+
return RewriteResult()
98+
99+
gates = self.decompose_U3_gates(node.operator.owner)
100+
101+
if len(gates) == 0:
102+
return RewriteResult()
103+
104+
for stmt_list in gates:
105+
for gate_stmt in stmt_list[:-1]:
106+
gate_stmt.insert_before(node)
107+
108+
oper = stmt_list[-1]
109+
oper.insert_before(node)
110+
new_node = node.__class__(operator=oper.result, qubits=node.qubits)
111+
new_node.insert_before(node)
112+
113+
node.delete()
114+
115+
# rewrite U3 to clifford gates
116+
return RewriteResult(has_done_something=True)
117+
118+
def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ...]:
119+
"""
120+
Rewrite U3 statements to clifford gates if possible.
121+
"""
122+
theta = self.get_constant(node.theta)
123+
phi = self.get_constant(node.phi)
124+
lam = self.get_constant(node.lam)
125+
126+
if theta is None or phi is None or lam is None:
127+
return ()
128+
129+
theta_half_pi: int | None = self.resolve_angle(theta)
130+
phi_half_pi: int | None = self.resolve_angle(phi)
131+
lam_half_pi: int | None = self.resolve_angle(lam)
132+
133+
if theta_half_pi is None or phi_half_pi is None or lam_half_pi is None:
134+
return ()
135+
136+
angles_key = (theta_half_pi, phi_half_pi, lam_half_pi)
137+
if angles_key not in U3_HALF_PI_ANGLE_TO_GATES:
138+
angles_key = equivalent_u3_para(*angles_key)
139+
if angles_key not in U3_HALF_PI_ANGLE_TO_GATES:
140+
return ()
141+
142+
gates_stmts = U3_HALF_PI_ANGLE_TO_GATES.get(angles_key)
143+
144+
# no consistent gates, then:
145+
assert (
146+
gates_stmts is not None
147+
), "internal error, U3 gates not found for angles: {}".format(angles_key)
148+
149+
return gates_stmts()

src/bloqade/stim/dialects/gate/stmts/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
from .pp import SPP as SPP
2+
from .base import (
3+
Gate as Gate,
4+
TwoQubitGate as TwoQubitGate,
5+
SingleQubitGate as SingleQubitGate,
6+
ControlledTwoQubitGate as ControlledTwoQubitGate,
7+
)
28
from .control_2q import CX as CX, CY as CY, CZ as CZ
39
from .clifford_1q import (
410
H as H,

src/bloqade/stim/rewrite/qubit_to_stim.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from bloqade.squin import op, qubit
55
from bloqade.squin.rewrite import AddressAttribute
6+
from bloqade.stim.dialects import gate
67
from bloqade.stim.rewrite.util import (
78
SQUIN_STIM_GATE_MAPPING,
89
rewrite_Control,
@@ -34,6 +35,15 @@ def rewrite_Apply_and_Broadcast(
3435
if isinstance(applied_op, op.stmts.Control):
3536
return rewrite_Control(stmt)
3637

38+
# check if its adjoint, assume its canonicalized so no nested adjoints.
39+
is_conj = False
40+
if isinstance(applied_op, op.stmts.Adjoint):
41+
if not applied_op.is_unitary:
42+
return RewriteResult()
43+
44+
is_conj = True
45+
applied_op = applied_op.op.owner
46+
3747
# need to handle Control through separate means
3848
# but we can handle X, Y, Z, H, and S here just fine
3949
stim_1q_op = SQUIN_STIM_GATE_MAPPING.get(type(applied_op))
@@ -54,7 +64,10 @@ def rewrite_Apply_and_Broadcast(
5464
if qubit_idx_ssas is None:
5565
return RewriteResult()
5666

57-
stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
67+
if isinstance(stim_1q_op, gate.stmts.Gate):
68+
stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas), dagger=is_conj)
69+
else:
70+
stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
5871
stmt.replace_by(stim_1q_stmt)
5972

6073
return RewriteResult(has_done_something=True)

src/bloqade/stim/rewrite/util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
op.stmts.Z: gate.Z,
1414
op.stmts.H: gate.H,
1515
op.stmts.S: gate.S,
16+
op.stmts.SqrtX: gate.SqrtX,
17+
op.stmts.SqrtY: gate.SqrtY,
1618
op.stmts.Identity: gate.Identity,
1719
op.stmts.Reset: collapse.RZ,
1820
}

0 commit comments

Comments
 (0)