Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions src/bloqade/squin/rewrite/non_clifford_to_U3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from kirin import ir
from kirin.rewrite import abc as rewrite_abc
from kirin.dialects import py

from bloqade.squin.gate import stmts as gate_stmts


class RewriteNonCliffordToU3(rewrite_abc.RewriteRule):
"""Rewrite non-Clifford gates to U3 gates.

This rewrite rule transforms specific non-Clifford single-qubit gates
into equivalent U3 gate representations. The following transformations are applied:
- T gate (with adjoint attribute) to U3 gate with parameters (0, 0, ±π/4)
- Rx gate to U3 gate with parameters (angle, -π/2, π/2)
- Ry gate to U3 gate with parameters (angle, 0, 0)
- Rz gate is U3 gate with parameters (0, 0, angle)

This rewrite should be paired with `U3ToClifford` to canonicalize the circuit.

"""

def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
if not isinstance(
node,
(
gate_stmts.T,
gate_stmts.Rx,
gate_stmts.Ry,
gate_stmts.Rz,
),
):
return rewrite_abc.RewriteResult()

rule = getattr(self, f"rewrite_{type(node).__name__}", self.default)

return rule(node)

def default(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
return rewrite_abc.RewriteResult()

def rewrite_T(self, node: gate_stmts.T) -> rewrite_abc.RewriteResult:
if node.adjoint:
lam_value = -1.0 / 8.0
else:
lam_value = 1.0 / 8.0

(theta_stmt := py.Constant(0.0)).insert_before(node)
(phi_stmt := py.Constant(0.0)).insert_before(node)
(lam_stmt := py.Constant(lam_value)).insert_before(node)

node.replace_by(
gate_stmts.U3(
qubits=node.qubits,
theta=theta_stmt.result,
phi=phi_stmt.result,
lam=lam_stmt.result,
)
)

return rewrite_abc.RewriteResult(has_done_something=True)

def rewrite_Rx(self, node: gate_stmts.Rx) -> rewrite_abc.RewriteResult:
(phi_stmt := py.Constant(-0.25)).insert_before(node)
(lam_stmt := py.Constant(0.25)).insert_before(node)

node.replace_by(
gate_stmts.U3(
qubits=node.qubits,
theta=node.angle,
phi=phi_stmt.result,
lam=lam_stmt.result,
)
)

return rewrite_abc.RewriteResult(has_done_something=True)

def rewrite_Ry(self, node: gate_stmts.Ry) -> rewrite_abc.RewriteResult:
(phi_stmt := py.Constant(0.0)).insert_before(node)
(lam_stmt := py.Constant(0.0)).insert_before(node)

node.replace_by(
gate_stmts.U3(
qubits=node.qubits,
theta=node.angle,
phi=phi_stmt.result,
lam=lam_stmt.result,
)
)

return rewrite_abc.RewriteResult(has_done_something=True)

def rewrite_Rz(self, node: gate_stmts.Rz) -> rewrite_abc.RewriteResult:
(theta_stmt := py.Constant(0.0)).insert_before(node)
(phi_stmt := py.Constant(0.0)).insert_before(node)

node.replace_by(
gate_stmts.U3(
qubits=node.qubits,
theta=theta_stmt.result,
phi=phi_stmt.result,
lam=node.angle,
)
)

return rewrite_abc.RewriteResult(has_done_something=True)
143 changes: 143 additions & 0 deletions test/squin/rewrite/test_nonclifford_to_U3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from kirin import ir, rewrite
from kirin.dialects import py

from bloqade.squin.gate import stmts as gate_stmts
from bloqade.test_utils import assert_nodes
from bloqade.squin.rewrite.non_clifford_to_U3 import RewriteNonCliffordToU3


def test_rewrite_T():
test_qubits = ir.TestValue()
test_block = ir.Block([gate_stmts.T(qubits=test_qubits, adjoint=False)])

expected_block = ir.Block(
[
theta := py.Constant(0.0),
phi := py.Constant(0.0),
lam := py.Constant(1.0 / 8.0),
gate_stmts.U3(
qubits=test_qubits,
theta=theta.result,
phi=phi.result,
lam=lam.result,
),
]
)

rule = rewrite.Walk(RewriteNonCliffordToU3())
rule.rewrite(test_block)

assert_nodes(test_block, expected_block)


def test_rewrite_Tadj():
test_qubits = ir.TestValue()
test_block = ir.Block([gate_stmts.T(qubits=test_qubits, adjoint=True)])

expected_block = ir.Block(
[
theta := py.Constant(0.0),
phi := py.Constant(0.0),
lam := py.Constant(-1.0 / 8.0),
gate_stmts.U3(
qubits=test_qubits,
theta=theta.result,
phi=phi.result,
lam=lam.result,
),
]
)

rule = rewrite.Walk(RewriteNonCliffordToU3())
rule.rewrite(test_block)

assert_nodes(test_block, expected_block)


def test_rewrite_Ry():
test_qubits = ir.TestValue()
angle = ir.TestValue()
test_block = ir.Block([gate_stmts.Ry(qubits=test_qubits, angle=angle)])

expected_block = ir.Block(
[
phi := py.Constant(0.0),
lam := py.Constant(0.0),
gate_stmts.U3(
qubits=test_qubits,
theta=angle,
phi=phi.result,
lam=lam.result,
),
]
)

rule = rewrite.Walk(RewriteNonCliffordToU3())
rule.rewrite(test_block)

assert_nodes(test_block, expected_block)


def test_rewrite_Rz():
test_qubits = ir.TestValue()
angle = ir.TestValue()
test_block = ir.Block([gate_stmts.Rz(qubits=test_qubits, angle=angle)])

expected_block = ir.Block(
[
theta := py.Constant(0.0),
phi := py.Constant(0.0),
gate_stmts.U3(
qubits=test_qubits,
theta=theta.result,
phi=phi.result,
lam=angle,
),
]
)

rule = rewrite.Walk(RewriteNonCliffordToU3())
rule.rewrite(test_block)

assert_nodes(test_block, expected_block)


def test_rewrite_Rx():
test_qubits = ir.TestValue()
angle = ir.TestValue()
test_block = ir.Block([gate_stmts.Rx(qubits=test_qubits, angle=angle)])

expected_block = ir.Block(
[
phi := py.Constant(-0.25),
lam := py.Constant(0.25),
gate_stmts.U3(
qubits=test_qubits,
theta=angle,
phi=phi.result,
lam=lam.result,
),
]
)

rule = rewrite.Walk(RewriteNonCliffordToU3())
rule.rewrite(test_block)

assert_nodes(test_block, expected_block)


def test_no_op():
test_qubits = ir.TestValue()
angle = ir.TestValue()
test_block = ir.Block(
[gate_stmts.U3(qubits=test_qubits, theta=angle, phi=angle, lam=angle)]
)

expected_block = ir.Block(
[gate_stmts.U3(qubits=test_qubits, theta=angle, phi=angle, lam=angle)]
)

rule = rewrite.Walk(RewriteNonCliffordToU3())
rule.rewrite(test_block)

assert_nodes(test_block, expected_block)
Loading