Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/bloqade/analysis/address/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,16 +366,22 @@ def for_loop(
if iter_type is None:
return interp_.eval_fallback(frame, stmt)

body_values = {}
for value in iterable:
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
loop_vars = interp_.frame_call_region(
body_frame, stmt, stmt.body, value, *loop_vars
)

for ssa, val in body_frame.entries.items():
body_values[ssa] = body_values.setdefault(ssa, val).join(val)

if loop_vars is None:
loop_vars = ()

elif isinstance(loop_vars, interp.ReturnValue):
frame.set_values(body_frame.entries.keys(), body_frame.entries.values())
return loop_vars

frame.set_values(body_values.keys(), body_values.values())
return loop_vars
105 changes: 105 additions & 0 deletions src/bloqade/squin/rewrite/non_clifford_to_U3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from kirin import ir
from kirin.rewrite import abc as rewrite_abc
from kirin.dialects import py

from bloqade.squin.gate import stmts as gate_stmts


class RewriteNonCliffordToU3(rewrite_abc.RewriteRule):
"""Rewrite non-Clifford gates to U3 gates.

This rewrite rule transforms specific non-Clifford single-qubit gates
into equivalent U3 gate representations. The following transformations are applied:
- T gate (with adjoint attribute) to U3 gate with parameters (0, 0, ±π/4)
- Rx gate to U3 gate with parameters (angle, -π/2, π/2)
- Ry gate to U3 gate with parameters (angle, 0, 0)
- Rz gate is U3 gate with parameters (0, 0, angle)

This rewrite should be paired with `U3ToClifford` to canonicalize the circuit.

"""

def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
if not isinstance(
node,
(
gate_stmts.T,
gate_stmts.Rx,
gate_stmts.Ry,
gate_stmts.Rz,
),
):
return rewrite_abc.RewriteResult()

rule = getattr(self, f"rewrite_{type(node).__name__}", self.default)

return rule(node)

def default(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
return rewrite_abc.RewriteResult()

def rewrite_T(self, node: gate_stmts.T) -> rewrite_abc.RewriteResult:
if node.adjoint:
lam_value = -1.0 / 8.0
else:
lam_value = 1.0 / 8.0

(theta_stmt := py.Constant(0.0)).insert_before(node)
(phi_stmt := py.Constant(0.0)).insert_before(node)
(lam_stmt := py.Constant(lam_value)).insert_before(node)

node.replace_by(
gate_stmts.U3(
qubits=node.qubits,
theta=theta_stmt.result,
phi=phi_stmt.result,
lam=lam_stmt.result,
)
)

return rewrite_abc.RewriteResult(has_done_something=True)

def rewrite_Rx(self, node: gate_stmts.Rx) -> rewrite_abc.RewriteResult:
(phi_stmt := py.Constant(-0.25)).insert_before(node)
(lam_stmt := py.Constant(0.25)).insert_before(node)

node.replace_by(
gate_stmts.U3(
qubits=node.qubits,
theta=node.angle,
phi=phi_stmt.result,
lam=lam_stmt.result,
)
)

return rewrite_abc.RewriteResult(has_done_something=True)

def rewrite_Ry(self, node: gate_stmts.Ry) -> rewrite_abc.RewriteResult:
(phi_stmt := py.Constant(0.0)).insert_before(node)
(lam_stmt := py.Constant(0.0)).insert_before(node)

node.replace_by(
gate_stmts.U3(
qubits=node.qubits,
theta=node.angle,
phi=phi_stmt.result,
lam=lam_stmt.result,
)
)

return rewrite_abc.RewriteResult(has_done_something=True)

def rewrite_Rz(self, node: gate_stmts.Rz) -> rewrite_abc.RewriteResult:
(theta_stmt := py.Constant(0.0)).insert_before(node)
(phi_stmt := py.Constant(0.0)).insert_before(node)

node.replace_by(
gate_stmts.U3(
qubits=node.qubits,
theta=theta_stmt.result,
phi=phi_stmt.result,
lam=node.angle,
)
)

return rewrite_abc.RewriteResult(has_done_something=True)
29 changes: 28 additions & 1 deletion test/analysis/address/test_qubit_analysis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from util import collect_address_types
from kirin.analysis import const
from kirin.dialects import ilist
from kirin.dialects import scf, ilist

from bloqade import qubit, squin
from bloqade.analysis import address
Expand Down Expand Up @@ -265,3 +265,30 @@ def main():

assert ret == address.AddressReg(data=tuple(range(20)))
assert analysis.qubit_count == 20


def test_for_loop_body_values():
@squin.kernel
def main():
q = squin.qalloc(4)
for i in range(1, len(q)):
squin.cx(q[0], q[i])

address_analysis = address.AddressAnalysis(main.dialects)
frame, result = address_analysis.run(main)
main.print(analysis=frame.entries)

(for_stmt,) = tuple(
stmt for stmt in main.callable_region.walk() if isinstance(stmt, scf.For)
)

for_analysis = [
value
for stmt in for_stmt.body.walk()
for value in frame.get_values(stmt.results)
]

assert address.AddressQubit(data=0) in for_analysis
assert address.ConstResult(const.Value(0)) in for_analysis
assert address.ConstResult(const.Value(None)) in for_analysis
assert address.Unknown() in for_analysis
143 changes: 143 additions & 0 deletions test/squin/rewrite/test_nonclifford_to_U3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from kirin import ir, rewrite
from kirin.dialects import py

from bloqade.squin.gate import stmts as gate_stmts
from bloqade.test_utils import assert_nodes
from bloqade.squin.rewrite.non_clifford_to_U3 import RewriteNonCliffordToU3


def test_rewrite_T():
test_qubits = ir.TestValue()
test_block = ir.Block([gate_stmts.T(qubits=test_qubits, adjoint=False)])

expected_block = ir.Block(
[
theta := py.Constant(0.0),
phi := py.Constant(0.0),
lam := py.Constant(1.0 / 8.0),
gate_stmts.U3(
qubits=test_qubits,
theta=theta.result,
phi=phi.result,
lam=lam.result,
),
]
)

rule = rewrite.Walk(RewriteNonCliffordToU3())
rule.rewrite(test_block)

assert_nodes(test_block, expected_block)


def test_rewrite_Tadj():
test_qubits = ir.TestValue()
test_block = ir.Block([gate_stmts.T(qubits=test_qubits, adjoint=True)])

expected_block = ir.Block(
[
theta := py.Constant(0.0),
phi := py.Constant(0.0),
lam := py.Constant(-1.0 / 8.0),
gate_stmts.U3(
qubits=test_qubits,
theta=theta.result,
phi=phi.result,
lam=lam.result,
),
]
)

rule = rewrite.Walk(RewriteNonCliffordToU3())
rule.rewrite(test_block)

assert_nodes(test_block, expected_block)


def test_rewrite_Ry():
test_qubits = ir.TestValue()
angle = ir.TestValue()
test_block = ir.Block([gate_stmts.Ry(qubits=test_qubits, angle=angle)])

expected_block = ir.Block(
[
phi := py.Constant(0.0),
lam := py.Constant(0.0),
gate_stmts.U3(
qubits=test_qubits,
theta=angle,
phi=phi.result,
lam=lam.result,
),
]
)

rule = rewrite.Walk(RewriteNonCliffordToU3())
rule.rewrite(test_block)

assert_nodes(test_block, expected_block)


def test_rewrite_Rz():
test_qubits = ir.TestValue()
angle = ir.TestValue()
test_block = ir.Block([gate_stmts.Rz(qubits=test_qubits, angle=angle)])

expected_block = ir.Block(
[
theta := py.Constant(0.0),
phi := py.Constant(0.0),
gate_stmts.U3(
qubits=test_qubits,
theta=theta.result,
phi=phi.result,
lam=angle,
),
]
)

rule = rewrite.Walk(RewriteNonCliffordToU3())
rule.rewrite(test_block)

assert_nodes(test_block, expected_block)


def test_rewrite_Rx():
test_qubits = ir.TestValue()
angle = ir.TestValue()
test_block = ir.Block([gate_stmts.Rx(qubits=test_qubits, angle=angle)])

expected_block = ir.Block(
[
phi := py.Constant(-0.25),
lam := py.Constant(0.25),
gate_stmts.U3(
qubits=test_qubits,
theta=angle,
phi=phi.result,
lam=lam.result,
),
]
)

rule = rewrite.Walk(RewriteNonCliffordToU3())
rule.rewrite(test_block)

assert_nodes(test_block, expected_block)


def test_no_op():
test_qubits = ir.TestValue()
angle = ir.TestValue()
test_block = ir.Block(
[gate_stmts.U3(qubits=test_qubits, theta=angle, phi=angle, lam=angle)]
)

expected_block = ir.Block(
[gate_stmts.U3(qubits=test_qubits, theta=angle, phi=angle, lam=angle)]
)

rule = rewrite.Walk(RewriteNonCliffordToU3())
rule.rewrite(test_block)

assert_nodes(test_block, expected_block)
Loading