Skip to content

Commit 1d1c980

Browse files
weinbe58Roger-luo
andauthored
Adding custom rewrite for ccx gate (#245)
closes #241 I have also added an option to `ParallelToUop` to decompose to the rydberg gateset before just to make sure the rewrite rule is applied properly. --------- Co-authored-by: Xiu-zhe (Roger) Luo <[email protected]>
1 parent e920ad6 commit 1d1c980

File tree

4 files changed

+101
-4
lines changed

4 files changed

+101
-4
lines changed

src/bloqade/qasm2/passes/parallel.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
RaiseRegisterRule,
2828
UOpToParallelRule,
2929
SimpleOptimalMergePolicy,
30+
RydbergGateSetRewriteRule,
3031
)
3132
from bloqade.squin.analysis import schedule
3233

@@ -135,6 +136,7 @@ def test():
135136
"""
136137

137138
merge_policy_type: Type[MergePolicyABC] = SimpleOptimalMergePolicy
139+
rewrite_to_native_first: bool = False
138140
constprop: const.Propagate = field(init=False)
139141

140142
def __post_init__(self):
@@ -147,6 +149,13 @@ def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
147149
if not result.has_done_something:
148150
return result
149151

152+
if self.rewrite_to_native_first:
153+
result = (
154+
Fixpoint(Walk(RydbergGateSetRewriteRule(self.dialects)))
155+
.rewrite(mt.code)
156+
.join(result)
157+
)
158+
150159
frame, _ = self.constprop.run_analysis(mt)
151160
result = Walk(WrapConst(frame)).rewrite(mt.code).join(result)
152161

src/bloqade/qasm2/rewrite/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
GlobalToParallelRule as GlobalToParallelRule,
44
)
55
from .register import RaiseRegisterRule as RaiseRegisterRule
6+
from .native_gates import RydbergGateSetRewriteRule as RydbergGateSetRewriteRule
67
from .parallel_to_uop import ParallelToUOpRule as ParallelToUOpRule
78
from .uop_to_parallel import (
89
MergePolicyABC as MergePolicyABC,

src/bloqade/qasm2/rewrite/native_gates.py

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,53 @@ def rewrite_sx(self, node: uop.SX) -> abc.RewriteResult:
173173
cirq.XPowGate(exponent=0.5).on(self.cached_qubits[0]), node
174174
)
175175

176+
def rewrite_ccx(self, node: uop.CCX) -> abc.RewriteResult:
177+
# from https://algassert.com/quirk#circuit=%7B%22cols%22:%5B%5B%22QFT3%22%5D,%5B%22inputA3%22,1,1,%22+=A3%22%5D,%5B1,1,1,%22%E2%80%A2%22,%22%E2%80%A2%22,%22X%22%5D,%5B1,1,1,%22%E2%80%A6%22,%22%E2%80%A6%22,%22%E2%80%A6%22%5D,%5B1,1,1,1,%22%E2%80%A2%22,%22Z%22%5D,%5B1,1,1,1,1,%22X%5E-%C2%BC%22%5D,%5B1,1,1,%22%E2%80%A2%22,1,%22Z%22%5D,%5B1,1,1,1,1,%22X%5E%C2%BC%22%5D,%5B1,1,1,1,%22%E2%80%A2%22,%22Z%22%5D,%5B1,1,1,1,1,%22X%5E-%C2%BC%22%5D,%5B1,1,1,%22Z%5E%C2%BC%22,%22Z%5E%C2%BC%22%5D,%5B1,1,1,1,%22H%22%5D,%5B1,1,1,%22%E2%80%A2%22,1,%22Z%22%5D,%5B1,1,1,%22%E2%80%A2%22,%22Z%22%5D,%5B1,1,1,1,%22X%5E-%C2%BC%22,%22X%5E%C2%BC%22%5D,%5B1,1,1,%22%E2%80%A2%22,%22Z%22%5D,%5B1,1,1,1,%22H%22%5D%5D%7D
178+
179+
# x^(1/4)
180+
lam1, theta1, phi1 = map(
181+
self.const_float,
182+
map(around, (1.5707963267948966, 0.7853981633974483, -1.5707963267948966)),
183+
)
184+
lam1.insert_before(node)
185+
theta1.insert_before(node)
186+
phi1.insert_before(node)
187+
188+
lam1 = lam1.result
189+
theta1 = theta1.result
190+
phi1 = phi1.result
191+
192+
# x^(-1/4)
193+
lam2, theta2, phi2 = map(
194+
self.const_float,
195+
map(around, (4.71238898038469, 0.7853981633974483, 1.5707963267948966)),
196+
)
197+
lam2.insert_before(node)
198+
theta2.insert_before(node)
199+
phi2.insert_before(node)
200+
lam2 = lam2.result
201+
theta2 = theta2.result
202+
phi2 = phi2.result
203+
204+
uop.CZ(ctrl=node.ctrl1, qarg=node.qarg).insert_before(node)
205+
uop.UGate(node.qarg, theta2, phi2, lam2).insert_before(node)
206+
uop.CZ(ctrl=node.ctrl2, qarg=node.qarg).insert_before(node)
207+
uop.UGate(node.qarg, theta1, phi1, lam1).insert_before(node)
208+
uop.CZ(ctrl=node.ctrl1, qarg=node.qarg).insert_before(node)
209+
uop.UGate(node.qarg, theta2, phi2, lam2).insert_before(node)
210+
uop.T(node.ctrl1).insert_before(node)
211+
uop.T(node.ctrl2).insert_before(node)
212+
uop.H(node.ctrl1).insert_before(node)
213+
uop.CZ(ctrl=node.ctrl2, qarg=node.qarg).insert_before(node)
214+
uop.CZ(ctrl=node.ctrl2, qarg=node.ctrl1).insert_before(node)
215+
uop.UGate(node.ctrl1, theta2, phi2, lam2).insert_before(node)
216+
uop.UGate(node.qarg, theta2, phi2, lam2).insert_before(node)
217+
uop.CZ(ctrl=node.ctrl2, qarg=node.ctrl1).insert_before(node)
218+
uop.H(node.ctrl1).insert_before(node)
219+
node.delete() # delete the original CCX gate
220+
221+
return abc.RewriteResult(has_done_something=True)
222+
176223
def rewrite_sxdg(self, node: uop.SXdag) -> abc.RewriteResult:
177224
return self._rewrite_1q_gates(
178225
cirq.XPowGate(exponent=-0.5).on(self.cached_qubits[0]), node
@@ -394,9 +441,12 @@ def _rewrite_1q_gates(
394441
new_gate_stmts = self._generate_1q_gate_stmts(cirq_gate, node.qarg)
395442
return self._rewrite_gate_stmts(new_gate_stmts, node)
396443

397-
def _generate_2q_ctrl_gate_stmts(
444+
def _generate_multi_ctrl_gate_stmts(
398445
self, cirq_gate: cirq.Operation, qubits_ssa: List[ir.SSAValue]
399446
) -> list[ir.Statement]:
447+
qubit_to_ssa_map = {
448+
q: ssa for q, ssa in zip(self.cached_qubits[: len(qubits_ssa)], qubits_ssa)
449+
}
400450
target_gates = self.gateset.decompose_to_target_gateset(cirq_gate, 0)
401451
new_stmts = []
402452
for new_gate in target_gates:
@@ -412,26 +462,39 @@ def _generate_2q_ctrl_gate_stmts(
412462
new_stmts.append(phi2_stmt)
413463
new_stmts.append(
414464
uop.UGate(
415-
qarg=qubits_ssa[new_gate.qubits[0].x],
465+
qarg=qubit_to_ssa_map[new_gate.qubits[0]],
416466
theta=phi0_stmt.result,
417467
phi=phi1_stmt.result,
418468
lam=phi2_stmt.result,
419469
)
420470
)
421471
else:
422472
# 2q
423-
new_stmts.append(uop.CZ(ctrl=qubits_ssa[0], qarg=qubits_ssa[1]))
473+
new_stmts.append(
474+
uop.CZ(
475+
ctrl=qubit_to_ssa_map[new_gate.qubits[0]],
476+
qarg=qubit_to_ssa_map[new_gate.qubits[1]],
477+
)
478+
)
424479

425480
return new_stmts
426481

427482
def _rewrite_2q_ctrl_gates(
428483
self, cirq_gate: cirq.Operation, node: uop.TwoQubitCtrlGate
429484
) -> abc.RewriteResult:
430-
new_gate_stmts = self._generate_2q_ctrl_gate_stmts(
485+
new_gate_stmts = self._generate_multi_ctrl_gate_stmts(
431486
cirq_gate, [node.ctrl, node.qarg]
432487
)
433488
return self._rewrite_gate_stmts(new_gate_stmts, node)
434489

490+
def _rewrite_3q_ctrl_gates(
491+
self, cirq_gate: cirq.Operation, node: uop.CCX
492+
) -> abc.RewriteResult:
493+
new_gate_stmts = self._generate_multi_ctrl_gate_stmts(
494+
cirq_gate, [node.ctrl1, node.ctrl2, node.qarg]
495+
)
496+
return self._rewrite_gate_stmts(new_gate_stmts, node)
497+
435498
def _rewrite_gate_stmts(
436499
self, new_gate_stmts: list[ir.Statement], node: ir.Statement
437500
):

test/qasm2/test_native.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
import textwrap
33

44
import cirq
5+
import numpy as np
56
import cirq.testing
67
import cirq.contrib.qasm_import as qasm_import
78
import cirq.circuits.qasm_output as qasm_output
89
from pytest import mark
910
from kirin.rewrite import walk
1011

1112
from bloqade import qasm2
13+
from bloqade.pyqrack import DynamicMemorySimulator
1214
from bloqade.qasm2.rewrite.native_gates import (
1315
RydbergGateSetRewriteRule,
1416
one_qubit_gate_to_u3_angles,
@@ -153,3 +155,25 @@ def kernel():
153155

154156
# simple-stupid test to see if the rewrite injected a bunch of new lines
155157
assert new_qasm2.count("\n") > prog.count("\n")
158+
159+
160+
def test_ccx_rewrite():
161+
162+
@qasm2.extended
163+
def main():
164+
q = qasm2.qreg(3)
165+
qasm2.ccx(q[0], q[1], q[2])
166+
167+
return q
168+
169+
main2 = main.similar()
170+
171+
walk.Walk(RydbergGateSetRewriteRule(main.dialects)).rewrite(main.code)
172+
173+
sim = DynamicMemorySimulator()
174+
175+
state = sim.state_vector(main)
176+
state2 = sim.state_vector(main2)
177+
assert (
178+
np.abs(np.vdot(state, state2)) - 1 < 1e-6
179+
) # Should be close to 1 if the states are equal

0 commit comments

Comments
 (0)