From 2f6e8a1d65127b1de98560cfae0d6454344e9979 Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 2 Dec 2025 09:02:14 -0500 Subject: [PATCH 01/13] core dialects/functionality supported with pass --- src/bloqade/squin/passes/__init__.py | 1 + src/bloqade/squin/passes/qasm2_to_squin.py | 42 +++++ src/bloqade/squin/rewrite/qasm2/__init__.py | 3 + .../squin/rewrite/qasm2/core_to_squin.py | 53 ++++++ .../squin/rewrite/qasm2/expr_to_squin.py | 113 +++++++++++++ .../squin/rewrite/qasm2/uop_to_squin.py | 155 ++++++++++++++++++ 6 files changed, 367 insertions(+) create mode 100644 src/bloqade/squin/passes/__init__.py create mode 100644 src/bloqade/squin/passes/qasm2_to_squin.py create mode 100644 src/bloqade/squin/rewrite/qasm2/__init__.py create mode 100644 src/bloqade/squin/rewrite/qasm2/core_to_squin.py create mode 100644 src/bloqade/squin/rewrite/qasm2/expr_to_squin.py create mode 100644 src/bloqade/squin/rewrite/qasm2/uop_to_squin.py diff --git a/src/bloqade/squin/passes/__init__.py b/src/bloqade/squin/passes/__init__.py new file mode 100644 index 00000000..6dceb533 --- /dev/null +++ b/src/bloqade/squin/passes/__init__.py @@ -0,0 +1 @@ +from .qasm2_to_squin import QASM2ToSquin as QASM2ToSquin diff --git a/src/bloqade/squin/passes/qasm2_to_squin.py b/src/bloqade/squin/passes/qasm2_to_squin.py new file mode 100644 index 00000000..c81d254a --- /dev/null +++ b/src/bloqade/squin/passes/qasm2_to_squin.py @@ -0,0 +1,42 @@ +from dataclasses import dataclass + +from kirin import ir +from kirin.passes import Fold, Pass, TypeInfer +from kirin.rewrite import Walk, Chain +from kirin.rewrite.abc import RewriteResult +from kirin.dialects.ilist.passes import IListDesugar + +from bloqade import squin +from bloqade.squin.rewrite.qasm2 import ( + QASM2UOPToSquin, + QASM2CoreToSquin, + QASM2ExprToSquin, +) + + +@dataclass +class QASM2ToSquin(Pass): + + def unsafe_run(self, mt: ir.Method) -> RewriteResult: + + # rewrite all QASM2 to squin first + rewrite_result = Walk( + Chain( + QASM2ExprToSquin(), + QASM2CoreToSquin(), + QASM2UOPToSquin(), + ) + ).rewrite(mt.code) + + # the rest is taken from the squin kernel + + rewrite_result = Fold(dialects=squin.kernel).fixpoint(mt) + rewrite_result = ( + TypeInfer(dialects=squin.kernel).unsafe_run(mt).join(rewrite_result) + ) + rewrite_result = ( + IListDesugar(dialects=squin.kernel).unsafe_run(mt).join(rewrite_result) + ) + TypeInfer(dialects=squin.kernel).unsafe_run(mt).join(rewrite_result) + + return rewrite_result diff --git a/src/bloqade/squin/rewrite/qasm2/__init__.py b/src/bloqade/squin/rewrite/qasm2/__init__.py new file mode 100644 index 00000000..40231fb5 --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/__init__.py @@ -0,0 +1,3 @@ +from .uop_to_squin import QASM2UOPToSquin as QASM2UOPToSquin +from .core_to_squin import QASM2CoreToSquin as QASM2CoreToSquin +from .expr_to_squin import QASM2ExprToSquin as QASM2ExprToSquin diff --git a/src/bloqade/squin/rewrite/qasm2/core_to_squin.py b/src/bloqade/squin/rewrite/qasm2/core_to_squin.py new file mode 100644 index 00000000..8418ce3d --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/core_to_squin.py @@ -0,0 +1,53 @@ +from kirin import ir +from kirin.dialects import py, func, ilist +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import squin +from bloqade.types import MeasurementResultType +from bloqade.qasm2.dialects.core import stmts as core_stmts + + +class QASM2CoreToSquin(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + match node: + case core_stmts.QRegNew(): + return self.rewrite_QRegNew(node) + case core_stmts.CRegNew(): + return self.rewrite_CRegNew(node) + case core_stmts.Reset(): + return self.rewrite_Reset(node) + case core_stmts.QRegGet(): + return self.rewrite_Get(node) + case _: + return RewriteResult() + + def rewrite_QRegNew(self, stmt: core_stmts.QRegNew) -> RewriteResult: + qalloc_invoke_stmt = func.Invoke( + callee=squin.qubit.qalloc, inputs=(stmt.n_qubits,) + ) + stmt.replace_by(qalloc_invoke_stmt) + return RewriteResult(has_done_something=True) + + def rewrite_CRegNew(self, stmt: core_stmts.CRegNew) -> RewriteResult: + + measurement_list = ilist.New(values=(), elem_type=MeasurementResultType) + stmt.replace_by(measurement_list) + return RewriteResult(has_done_something=True) + + def rewrite_Reset(self, stmt: core_stmts.Reset) -> RewriteResult: + + squin_reset_stmt = squin.qubit.stmts.Reset(qubits=stmt.qarg) + stmt.replace_by(squin_reset_stmt) + return RewriteResult(has_done_something=True) + + def rewrite_Get(self, stmt: core_stmts.QRegGet) -> RewriteResult: + + get_item_stmt = py.GetItem( + obj=stmt.reg, + index=stmt.idx, + ) + + stmt.replace_by(get_item_stmt) + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/qasm2/expr_to_squin.py b/src/bloqade/squin/rewrite/qasm2/expr_to_squin.py new file mode 100644 index 00000000..5b913562 --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/expr_to_squin.py @@ -0,0 +1,113 @@ +from math import pi + +from kirin import ir +from kirin.dialects import py, math as kirin_math +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade.qasm2.dialects.expr import stmts as expr_stmts + +qasm2_binops = [] + + +class QASM2ExprToSquin(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + match node: + case expr_stmts.ConstInt() | expr_stmts.ConstFloat(): + return self.rewrite_Const(node) + case expr_stmts.ConstPI(): + return self.rewrite_PI(node) + case ( + expr_stmts.Mul() + | expr_stmts.Add() + | expr_stmts.Sub() + | expr_stmts.Div() + | expr_stmts.Pow() + ): + return self.rewrite_BinOp(node) + case ( + expr_stmts.Neg() + | expr_stmts.Sin() + | expr_stmts.Cos() + | expr_stmts.Tan() + | expr_stmts.Exp() + | expr_stmts.Log() + | expr_stmts.Sqrt() + ): + return self.rewrite_UnaryOp(node) + case _: + return RewriteResult() + + def rewrite_Const( + self, stmt: expr_stmts.ConstInt | expr_stmts.ConstFloat + ) -> RewriteResult: + + py_const = py.Constant(value=stmt.value) + stmt.replace_by(py_const) + return RewriteResult(has_done_something=True) + + def rewrite_PI(self, stmt: expr_stmts.ConstPI) -> RewriteResult: + + py_const = py.Constant(value=pi) + stmt.replace_by(py_const) + return RewriteResult(has_done_something=True) + + def rewrite_BinOp(self, stmt: ir.Statement) -> RewriteResult: + + match stmt: + case expr_stmts.Mul(): + op = py.binop.Mult + case expr_stmts.Add(): + op = py.binop.Add + case expr_stmts.Sub(): + op = py.binop.Sub + case expr_stmts.Div(): + op = py.binop.Div + case expr_stmts.Pow(): + op = py.binop.Pow + case _: + return RewriteResult() + + lhs = stmt.lhs + rhs = stmt.rhs + binop_expr = op(lhs=lhs, rhs=rhs) + stmt.replace_by(binop_expr) + return RewriteResult(has_done_something=True) + + def rewrite_UnaryOp( + self, + stmt: ( + expr_stmts.Neg + | expr_stmts.Sin + | expr_stmts.Cos + | expr_stmts.Tan + | expr_stmts.Exp + | expr_stmts.Log + | expr_stmts.Sqrt + ), + ) -> RewriteResult: + + match stmt: + case expr_stmts.Neg(): + op = py.unary.stmts.USub + case expr_stmts.Sin(): + op = kirin_math.stmts.sin + case expr_stmts.Cos(): + op = kirin_math.stmts.cos + case expr_stmts.Tan(): + op = kirin_math.stmts.tan + case expr_stmts.Exp(): + op = kirin_math.stmts.exp + case expr_stmts.Log(): + op = kirin_math.stmts.log2 + case expr_stmts.Sqrt(): + op = kirin_math.stmts.sqrt + case _: + return RewriteResult() + + value = stmt.value + unary_expr = op(value) + + stmt.replace_by(unary_expr) + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/qasm2/uop_to_squin.py b/src/bloqade/squin/rewrite/qasm2/uop_to_squin.py new file mode 100644 index 00000000..0c3afc03 --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/uop_to_squin.py @@ -0,0 +1,155 @@ +from math import pi + +from kirin import ir +from kirin.dialects import py, func +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import squin +from bloqade.qasm2.dialects.uop import stmts as uop_stmts + + +# assume that qasm2.core conversion has already run beforehand +class QASM2UOPToSquin(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + match node: + case uop_stmts.CX() | uop_stmts.CZ() | uop_stmts.CY(): + return self.rewrite_TwoQubitCtrlGate(node) + case ( + uop_stmts.X() + | uop_stmts.Y() + | uop_stmts.Z() + | uop_stmts.H() + | uop_stmts.S() + | uop_stmts.T() + | uop_stmts.SX() + ): + return self.rewrite_SingleQubitGate_no_parameters(node) + case uop_stmts.RZ() | uop_stmts.RX() | uop_stmts.RY(): + return self.rewrite_SingleQubit_with_parameters(node) + case uop_stmts.UGate() | uop_stmts.U1() | uop_stmts.U2(): + return self.rewrite_u_gates(node) + case _: + return RewriteResult() + + def rewrite_TwoQubitCtrlGate( + self, stmt: uop_stmts.CX | uop_stmts.CZ | uop_stmts.CY + ) -> RewriteResult: + + # qasm2 does not have broadcast semantics + # don't have to worry about these being lists + carg = stmt.ctrl + qarg = stmt.qarg + + match stmt: + case uop_stmts.CX(): + squin_2q_stdlib = squin.cx + case uop_stmts.CZ(): + squin_2q_stdlib = squin.cz + case uop_stmts.CY(): + squin_2q_stdlib = squin.cy + case _: + return RewriteResult() + + invoke_stmt = func.Invoke( + callee=squin_2q_stdlib, + inputs=(carg, qarg), + ) + + stmt.replace_by(invoke_stmt) + + return RewriteResult(has_done_something=True) + + def rewrite_SingleQubitGate_no_parameters( + self, + stmt: ( + uop_stmts.X + | uop_stmts.Y + | uop_stmts.Z + | uop_stmts.H + | uop_stmts.S + | uop_stmts.T + ), + ) -> RewriteResult: + + qarg = stmt.qarg + match stmt: + case uop_stmts.X(): + squin_1q_stdlib = squin.x + case uop_stmts.Y(): + squin_1q_stdlib = squin.y + case uop_stmts.Z(): + squin_1q_stdlib = squin.z + case uop_stmts.H(): + squin_1q_stdlib = squin.h + case uop_stmts.S(): + squin_1q_stdlib = squin.s + case uop_stmts.T(): + squin_1q_stdlib = squin.t + case uop_stmts.SX(): + squin_1q_stdlib = squin.sqrt_x + case _: + return RewriteResult() + + invoke_stmt = func.Invoke( + callee=squin_1q_stdlib, + inputs=(qarg,), + ) + + stmt.replace_by(invoke_stmt) + + return RewriteResult(has_done_something=True) + + def rewrite_u_gates( + self, stmt: uop_stmts.UGate | uop_stmts.U1 | uop_stmts.U2 + ) -> RewriteResult: + + match stmt: + case uop_stmts.UGate(lam=lam, phi=phi, theta=theta, qarg=qarg): + args = (theta, phi, lam, qarg) + case uop_stmts.U1(lam=lam, qarg=qarg): + zero_stmt = py.Constant(value=0.0) + zero_stmt.insert_before(stmt) + args = (zero_stmt.result, zero_stmt.result, lam, qarg) + case uop_stmts.U2(phi=phi, lam=lam, qarg=qarg): + half_pi_stmt = py.Constant(value=pi / 2) + half_pi_stmt.insert_before(stmt) + args = (half_pi_stmt.result, phi, lam, qarg) + case _: + return RewriteResult() + + invoke_stmt = func.Invoke( + callee=squin.u3, + inputs=args, + ) + + stmt.replace_by(invoke_stmt) + + return RewriteResult(has_done_something=True) + + def rewrite_SingleQubit_with_parameters( + self, stmt: uop_stmts.RZ | uop_stmts.RX | uop_stmts.RY + ) -> RewriteResult: + + qarg = stmt.qarg + theta = stmt.theta + + match stmt: + case uop_stmts.RZ(): + squin_1q_stdlib = squin.rz + case uop_stmts.RX(): + squin_1q_stdlib = squin.rx + case uop_stmts.RY(): + squin_1q_stdlib = squin.ry + case _: + return RewriteResult() + + invoke_stmt = func.Invoke( + callee=squin_1q_stdlib, + inputs=(qarg, theta), + ) + + stmt.replace_by(invoke_stmt) + + return RewriteResult(has_done_something=True) From 90d857864771ef70f1352fe0e7cabdf3358cd537 Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 2 Dec 2025 14:58:24 -0500 Subject: [PATCH 02/13] parallel, global, and noise dialects now supported --- src/bloqade/squin/passes/qasm2_to_squin.py | 4 + src/bloqade/squin/rewrite/qasm2/__init__.py | 2 + .../rewrite/qasm2/glob_parallel_to_squin.py | 51 +++++++++ .../squin/rewrite/qasm2/noise_to_squin.py | 107 ++++++++++++++++++ test/squin/rewrite/test_qasm2_to_squin.py | 104 +++++++++++++++++ 5 files changed, 268 insertions(+) create mode 100644 src/bloqade/squin/rewrite/qasm2/glob_parallel_to_squin.py create mode 100644 src/bloqade/squin/rewrite/qasm2/noise_to_squin.py create mode 100644 test/squin/rewrite/test_qasm2_to_squin.py diff --git a/src/bloqade/squin/passes/qasm2_to_squin.py b/src/bloqade/squin/passes/qasm2_to_squin.py index c81d254a..ce0831dc 100644 --- a/src/bloqade/squin/passes/qasm2_to_squin.py +++ b/src/bloqade/squin/passes/qasm2_to_squin.py @@ -11,6 +11,8 @@ QASM2UOPToSquin, QASM2CoreToSquin, QASM2ExprToSquin, + QASM2NoiseToSquin, + QASM2GlobParallelToSquin, ) @@ -25,6 +27,8 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult: QASM2ExprToSquin(), QASM2CoreToSquin(), QASM2UOPToSquin(), + QASM2GlobParallelToSquin(), + QASM2NoiseToSquin(), ) ).rewrite(mt.code) diff --git a/src/bloqade/squin/rewrite/qasm2/__init__.py b/src/bloqade/squin/rewrite/qasm2/__init__.py index 40231fb5..c2355fde 100644 --- a/src/bloqade/squin/rewrite/qasm2/__init__.py +++ b/src/bloqade/squin/rewrite/qasm2/__init__.py @@ -1,3 +1,5 @@ from .uop_to_squin import QASM2UOPToSquin as QASM2UOPToSquin from .core_to_squin import QASM2CoreToSquin as QASM2CoreToSquin from .expr_to_squin import QASM2ExprToSquin as QASM2ExprToSquin +from .noise_to_squin import QASM2NoiseToSquin as QASM2NoiseToSquin +from .glob_parallel_to_squin import QASM2GlobParallelToSquin as QASM2GlobParallelToSquin diff --git a/src/bloqade/squin/rewrite/qasm2/glob_parallel_to_squin.py b/src/bloqade/squin/rewrite/qasm2/glob_parallel_to_squin.py new file mode 100644 index 00000000..65322d00 --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/glob_parallel_to_squin.py @@ -0,0 +1,51 @@ +from kirin import ir +from kirin.dialects import func +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import squin +from bloqade.qasm2.dialects import glob, parallel + + +class QASM2GlobParallelToSquin(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + match node: + case glob.UGate() | parallel.UGate() | parallel.RZ(): + return self.rewrite_1q_gates(node) + case _: + return RewriteResult() + + return RewriteResult(has_done_something=True) + + def rewrite_1q_gates( + self, stmt: glob.UGate | parallel.UGate | parallel.RZ + ) -> RewriteResult: + + match stmt: + case glob.UGate(theta=theta, phi=phi, lam=lam) | parallel.UGate( + theta=theta, phi=phi, lam=lam + ): + # ever so slight naming difference, + # exists because intended semantics are different + match stmt: + case glob.UGate(): + qargs = stmt.registers + case parallel.UGate(): + qargs = stmt.qargs + + invoke_u_broadcast_stmt = func.Invoke( + callee=squin.broadcast.u3, + inputs=(theta, phi, lam, qargs), + ) + stmt.replace_by(invoke_u_broadcast_stmt) + case parallel.RZ(theta=theta, qargs=qargs): + invoke_rz_broadcast_stmt = func.Invoke( + callee=squin.broadcast.rz, + inputs=(theta, qargs), + ) + stmt.replace_by(invoke_rz_broadcast_stmt) + case _: + return RewriteResult() + + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/qasm2/noise_to_squin.py b/src/bloqade/squin/rewrite/qasm2/noise_to_squin.py new file mode 100644 index 00000000..625d9a29 --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/noise_to_squin.py @@ -0,0 +1,107 @@ +from kirin import ir +from kirin.dialects import py, func +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import squin +from bloqade.qasm2.dialects.noise import stmts as noise_stmts + + +class QASM2NoiseToSquin(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + match node: + case noise_stmts.AtomLossChannel(): + return self.rewrite_AtomLossChannel(node) + case noise_stmts.PauliChannel(): + return self.rewrite_PauliChannel(node) + case noise_stmts.CZPauliChannel(): + return self.rewrite_CZPauliChannel(node) + case _: + return RewriteResult() + + return RewriteResult() + + def rewrite_AtomLossChannel( + self, stmt: noise_stmts.AtomLossChannel + ) -> RewriteResult: + + qargs = stmt.qargs + # this is a raw float, not in SSA form yet! + prob = stmt.prob + prob_stmt = py.Constant(value=prob) + prob_stmt.insert_before(stmt) + + invoke_loss_stmt = func.Invoke( + callee=squin.broadcast.qubit_loss, + inputs=(prob_stmt.result, qargs), + ) + + stmt.replace_by(invoke_loss_stmt) + + return RewriteResult(has_done_something=True) + + def rewrite_PauliChannel(self, stmt: noise_stmts.PauliChannel) -> RewriteResult: + + qargs = stmt.qargs + p_x = stmt.px + p_y = stmt.py + p_z = stmt.pz + + probs = [p_x, p_y, p_z] + probs_ssas = [] + + for prob in probs: + prob_stmt = py.Constant(value=prob) + prob_stmt.insert_before(stmt) + probs_ssas.append(prob_stmt.result) + + invoke_pauli_channel_stmt = func.Invoke( + callee=squin.broadcast.single_qubit_pauli_channel, + inputs=(*probs_ssas, qargs), + ) + + stmt.replace_by(invoke_pauli_channel_stmt) + return RewriteResult(has_done_something=True) + + def rewrite_CZPauliChannel(self, stmt: noise_stmts.CZPauliChannel) -> RewriteResult: + + ctrls = stmt.ctrls + qargs = stmt.qargs + + px_ctrl = stmt.px_ctrl + py_ctrl = stmt.py_ctrl + pz_ctrl = stmt.pz_ctrl + px_qarg = stmt.px_qarg + py_qarg = stmt.py_qarg + pz_qarg = stmt.pz_qarg + + error_probs = [px_ctrl, py_ctrl, pz_ctrl, px_qarg, py_qarg, pz_qarg] + # first half of entries for control qubits, other half for targets + error_prob_ssas = [] + for error_prob in error_probs: + error_prob_stmt = py.Constant(value=error_prob) + error_prob_stmt.insert_before(stmt) + error_prob_ssas.append(error_prob_stmt.result) + + ctrl_pauli_channel_invoke = func.Invoke( + callee=squin.broadcast.single_qubit_pauli_channel, + inputs=( + *error_prob_ssas[:3], + ctrls, + ), + ) + + qarg_pauli_channel_invoke = func.Invoke( + callee=squin.broadcast.single_qubit_pauli_channel, + inputs=( + *error_prob_ssas[3:], + qargs, + ), + ) + + ctrl_pauli_channel_invoke.insert_before(stmt) + qarg_pauli_channel_invoke.insert_before(stmt) + stmt.delete() + + return RewriteResult(has_done_something=True) diff --git a/test/squin/rewrite/test_qasm2_to_squin.py b/test/squin/rewrite/test_qasm2_to_squin.py new file mode 100644 index 00000000..6cc75817 --- /dev/null +++ b/test/squin/rewrite/test_qasm2_to_squin.py @@ -0,0 +1,104 @@ +from kirin.rewrite import Walk, Chain +from kirin.passes.inline import InlinePass + +from bloqade import qasm2, squin +from bloqade.qasm2 import glob, noise, parallel + +# from kirin import prelude +# from kirin.passes import Fold, TypeInfer +# from kirin.dialects.ilist.passes import IListDesugar +from bloqade.squin.passes import QASM2ToSquin +from bloqade.squin.rewrite.qasm2 import ( + QASM2CoreToSquin, + QASM2ExprToSquin, +) + + +def test_expr_rewrite(): + + @qasm2.main + def expr_program(): + x = 0 + z = -1.7 + y = x + z + qasm2.sin(y) + return y + + Walk(QASM2ExprToSquin()).rewrite(expr_program.code) + + expr_program.print() + + +def test_qasm2_core(): + + @qasm2.main(fold=False) + def measure_kern(): + q = qasm2.qreg(5) + q0 = q[0] + return q0 + + measure_kern.print() + + Walk(Chain(QASM2ExprToSquin(), QASM2CoreToSquin())).rewrite(measure_kern.code) + + measure_kern.print() + + +def test_gates(): + + @qasm2.main + def gate_program(): + q = qasm2.qreg(10) + qasm2.cx(q[0], q[1]) + qasm2.z(q[3]) + qasm2.u3(q[4], 1.57, 0.0, 3.14) + qasm2.u1(q[5], 0.78) + qasm2.rz(q[6], 2.34) + qasm2.sx(q[7]) + return q + + gate_program.print() + QASM2ToSquin(dialects=squin.kernel)(gate_program) + gate_program.print() + + +def test_noise(): + + @qasm2.extended + def noise_program(): + q = qasm2.qreg(4) + noise.atom_loss_channel(qargs=q, prob=0.05) + noise.pauli_channel(qargs=q, px=0.01, py=0.02, pz=0.03) + noise.cz_pauli_channel( + ctrls=[q[0], q[1]], + qargs=[q[2], q[3]], + px_ctrl=0.01, + py_ctrl=0.02, + pz_ctrl=0.03, + px_qarg=0.04, + py_qarg=0.05, + pz_qarg=0.06, + paired=True, + ) + return q + + noise_program.print() + QASM2ToSquin(dialects=squin.kernel)(noise_program) + InlinePass(dialects=squin.kernel)(noise_program) + + +def test_global_and_parallel(): + + @qasm2.extended + def global_parallel_program(): + q = qasm2.qreg(6) + parallel.u([q[0]], 1.0, 0.0, 3.14) + glob.u([q[1], q[2], q[3]], 0.5, 1.57, 0.0) + parallel.rz([q[4], q[5]], 2.34) + + return q + + global_parallel_program.print() + QASM2ToSquin(dialects=squin.kernel)(global_parallel_program) + InlinePass(dialects=squin.kernel)(global_parallel_program) + global_parallel_program.print() From 509b9cc5c29b4f549cc97b6f5bcc92df45757a24 Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 4 Dec 2025 08:35:44 -0500 Subject: [PATCH 03/13] wip on converting invoke/functions --- src/bloqade/squin/passes/qasm2_to_squin.py | 4 + src/bloqade/squin/rewrite/qasm2/__init__.py | 1 + .../squin/rewrite/qasm2/core_to_squin.py | 54 +-- .../squin/rewrite/qasm2/func_to_squin.py | 25 ++ .../squin/rewrite/qasm2/uop_to_squin.py | 13 +- test/squin/passes/test_qasm2_to_squin.py | 341 ++++++++++++++++++ test/squin/rewrite/test_qasm2_to_squin.py | 104 ------ 7 files changed, 400 insertions(+), 142 deletions(-) create mode 100644 src/bloqade/squin/rewrite/qasm2/func_to_squin.py create mode 100644 test/squin/passes/test_qasm2_to_squin.py delete mode 100644 test/squin/rewrite/test_qasm2_to_squin.py diff --git a/src/bloqade/squin/passes/qasm2_to_squin.py b/src/bloqade/squin/passes/qasm2_to_squin.py index ce0831dc..fa31e8e3 100644 --- a/src/bloqade/squin/passes/qasm2_to_squin.py +++ b/src/bloqade/squin/passes/qasm2_to_squin.py @@ -11,6 +11,7 @@ QASM2UOPToSquin, QASM2CoreToSquin, QASM2ExprToSquin, + QASM2FuncToSquin, QASM2NoiseToSquin, QASM2GlobParallelToSquin, ) @@ -24,6 +25,7 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult: # rewrite all QASM2 to squin first rewrite_result = Walk( Chain( + QASM2FuncToSquin(), QASM2ExprToSquin(), QASM2CoreToSquin(), QASM2UOPToSquin(), @@ -43,4 +45,6 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult: ) TypeInfer(dialects=squin.kernel).unsafe_run(mt).join(rewrite_result) + mt.dialects = squin.kernel + return rewrite_result diff --git a/src/bloqade/squin/rewrite/qasm2/__init__.py b/src/bloqade/squin/rewrite/qasm2/__init__.py index c2355fde..7193afeb 100644 --- a/src/bloqade/squin/rewrite/qasm2/__init__.py +++ b/src/bloqade/squin/rewrite/qasm2/__init__.py @@ -1,5 +1,6 @@ from .uop_to_squin import QASM2UOPToSquin as QASM2UOPToSquin from .core_to_squin import QASM2CoreToSquin as QASM2CoreToSquin from .expr_to_squin import QASM2ExprToSquin as QASM2ExprToSquin +from .func_to_squin import QASM2FuncToSquin as QASM2FuncToSquin from .noise_to_squin import QASM2NoiseToSquin as QASM2NoiseToSquin from .glob_parallel_to_squin import QASM2GlobParallelToSquin as QASM2GlobParallelToSquin diff --git a/src/bloqade/squin/rewrite/qasm2/core_to_squin.py b/src/bloqade/squin/rewrite/qasm2/core_to_squin.py index 8418ce3d..27ea5435 100644 --- a/src/bloqade/squin/rewrite/qasm2/core_to_squin.py +++ b/src/bloqade/squin/rewrite/qasm2/core_to_squin.py @@ -1,9 +1,8 @@ from kirin import ir -from kirin.dialects import py, func, ilist +from kirin.dialects import py, func from kirin.rewrite.abc import RewriteRule, RewriteResult from bloqade import squin -from bloqade.types import MeasurementResultType from bloqade.qasm2.dialects.core import stmts as core_stmts @@ -12,42 +11,23 @@ class QASM2CoreToSquin(RewriteRule): def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: match node: - case core_stmts.QRegNew(): - return self.rewrite_QRegNew(node) - case core_stmts.CRegNew(): - return self.rewrite_CRegNew(node) - case core_stmts.Reset(): - return self.rewrite_Reset(node) - case core_stmts.QRegGet(): - return self.rewrite_Get(node) + case core_stmts.QRegNew(n_qubits=n_qubits): + qalloc_invoke_stmt = func.Invoke( + callee=squin.qubit.qalloc, inputs=(n_qubits,) + ) + node.replace_by(qalloc_invoke_stmt) + case core_stmts.Reset(qarg=qarg): + reset_invoke_stmt = func.Invoke( + callee=squin.qubit.reset, inputs=(qarg,) + ) + node.replace_by(reset_invoke_stmt) + case core_stmts.QRegGet(reg=reg, idx=idx): + get_item_stmt = py.GetItem( + obj=reg, + index=idx, + ) + node.replace_by(get_item_stmt) case _: return RewriteResult() - def rewrite_QRegNew(self, stmt: core_stmts.QRegNew) -> RewriteResult: - qalloc_invoke_stmt = func.Invoke( - callee=squin.qubit.qalloc, inputs=(stmt.n_qubits,) - ) - stmt.replace_by(qalloc_invoke_stmt) - return RewriteResult(has_done_something=True) - - def rewrite_CRegNew(self, stmt: core_stmts.CRegNew) -> RewriteResult: - - measurement_list = ilist.New(values=(), elem_type=MeasurementResultType) - stmt.replace_by(measurement_list) - return RewriteResult(has_done_something=True) - - def rewrite_Reset(self, stmt: core_stmts.Reset) -> RewriteResult: - - squin_reset_stmt = squin.qubit.stmts.Reset(qubits=stmt.qarg) - stmt.replace_by(squin_reset_stmt) - return RewriteResult(has_done_something=True) - - def rewrite_Get(self, stmt: core_stmts.QRegGet) -> RewriteResult: - - get_item_stmt = py.GetItem( - obj=stmt.reg, - index=stmt.idx, - ) - - stmt.replace_by(get_item_stmt) return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/qasm2/func_to_squin.py b/src/bloqade/squin/rewrite/qasm2/func_to_squin.py new file mode 100644 index 00000000..55ca9673 --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/func_to_squin.py @@ -0,0 +1,25 @@ +from kirin import ir +from kirin.dialects import func +from kirin.rewrite.abc import RewriteRule, RewriteResult + + +class QASM2FuncToSquin(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + if not isinstance(node, func.Invoke): + return RewriteResult() + + callee = node.callee + + return self.rewrite_Region(callee.callable_region) + + def rewrite_Region(self, region: ir.Region) -> RewriteResult: + + rewrite_result = RewriteResult() + + for stmt in list(region.walk()): + result = self.rewrite_Statement(stmt) + rewrite_result = rewrite_result.join(result) + + return rewrite_result diff --git a/src/bloqade/squin/rewrite/qasm2/uop_to_squin.py b/src/bloqade/squin/rewrite/qasm2/uop_to_squin.py index 0c3afc03..7a905d14 100644 --- a/src/bloqade/squin/rewrite/qasm2/uop_to_squin.py +++ b/src/bloqade/squin/rewrite/qasm2/uop_to_squin.py @@ -30,6 +30,8 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: return self.rewrite_SingleQubit_with_parameters(node) case uop_stmts.UGate() | uop_stmts.U1() | uop_stmts.U2(): return self.rewrite_u_gates(node) + case uop_stmts.Id(): + return self.rewrite_Id(node) case _: return RewriteResult() @@ -147,9 +149,18 @@ def rewrite_SingleQubit_with_parameters( invoke_stmt = func.Invoke( callee=squin_1q_stdlib, - inputs=(qarg, theta), + inputs=(theta, qarg), ) stmt.replace_by(invoke_stmt) return RewriteResult(has_done_something=True) + + def rewrite_Id(self, stmt: uop_stmts.Id) -> RewriteResult: + + # Identity does not exist in squin, + # we can just remove it from the program + + stmt.delete() + + return RewriteResult(has_done_something=True) diff --git a/test/squin/passes/test_qasm2_to_squin.py b/test/squin/passes/test_qasm2_to_squin.py new file mode 100644 index 00000000..01e7fea6 --- /dev/null +++ b/test/squin/passes/test_qasm2_to_squin.py @@ -0,0 +1,341 @@ +import math + +from kirin import types as kirin_types +from kirin.rewrite import Walk +from kirin.dialects import py, func, math as kirin_math +from kirin.dialects.ilist import IListType + +from bloqade import qasm2, squin +from bloqade.qasm2 import glob, noise, parallel +from bloqade.types import QubitType +from bloqade.squin.passes import QASM2ToSquin +from bloqade.rewrite.passes import AggressiveUnroll +from bloqade.analysis.address import AddressAnalysis +from bloqade.squin.rewrite.qasm2 import ( + QASM2ExprToSquin, +) +from bloqade.analysis.address.lattice import AddressReg + + +def test_expr_rewrite(): + + @qasm2.main + def expr_program(): + # constants + x = 0 # noqa: F841 + # ConstPI only added from lowering + y = qasm2.dialects.expr.stmts.ConstPI() # noqa: F841 + z = -1.75 # noqa: F841 + # binary ops + a = 1 + 1 # noqa: F841 + b = 2 * 2 # noqa: F841 + c = 3 - 3 # noqa: F841 + d = 4 / 4 # noqa: F841 + e = 5**2 # noqa: F841 + + # math + a = 0.2 + qasm2.sin(a) + qasm2.cos(a) + qasm2.tan(a) + qasm2.exp(a) + qasm2.ln(a) + qasm2.sqrt(a) + return + + expr_program.print() + + Walk(QASM2ExprToSquin()).rewrite(expr_program.code) + + expr_program.print() + + actual_stmt_sequence = list(expr_program.callable_region.walk()) + + def is_pi_const(stmt: py.Constant): + return isinstance(stmt, py.Constant) and math.isclose(stmt.value.data, math.pi) + + assert any(is_pi_const(stmt) for stmt in actual_stmt_sequence) + + assert qasm2.expr.ConstFloat not in actual_stmt_sequence + assert qasm2.expr.ConstInt not in actual_stmt_sequence + assert qasm2.expr.ConstPI not in actual_stmt_sequence + + no_const_actual_sequence = [ + type(stmt) + for stmt in actual_stmt_sequence + if not isinstance(stmt, (py.Constant, func.ConstantNone, func.Return)) + ] + + expected_stmt_sequence = [ + py.unary.stmts.USub, + py.binop.Add, + py.binop.Mult, + py.binop.Sub, + py.binop.Div, + py.binop.Pow, + kirin_math.stmts.sin, + kirin_math.stmts.cos, + kirin_math.stmts.tan, + kirin_math.stmts.exp, + kirin_math.stmts.log2, + kirin_math.stmts.sqrt, + ] + + assert no_const_actual_sequence == expected_stmt_sequence + + +def test_qasm2_core(): + + @qasm2.main(fold=False) + def core_kernel(): + q = qasm2.qreg(5) + q0 = q[0] + qasm2.reset(q0) + return q0 + + QASM2ToSquin(dialects=squin.kernel)(core_kernel) + + stmts = list(core_kernel.callable_region.walk()) + get_item_stmt = [stmt for stmt in stmts if isinstance(stmt, py.GetItem)] + assert len(get_item_stmt) == 1 + get_item_stmt = get_item_stmt[0] + assert get_item_stmt.obj.type == IListType[QubitType, kirin_types.Any] + idx_const = get_item_stmt.index.owner + assert idx_const.value.data == 0 + + # do aggressive unroll, confirm there are 5 qubit.news() + AggressiveUnroll(dialects=squin.kernel).fixpoint(core_kernel) + + unrolled_stmts = list(core_kernel.callable_region.walk()) + filtered_stmts = [ + stmt + for stmt in unrolled_stmts + if isinstance(stmt, (squin.qubit.stmts.New, squin.qubit.stmts.Reset)) + ] + expected_stmts = [squin.qubit.stmts.New] * 5 + [squin.qubit.stmts.Reset] + + assert [type(stmt) for stmt in filtered_stmts] == expected_stmts + + +def test_non_parametric_gates(): + + @qasm2.main + def non_parametric_gates(): + # 1q gates + q = qasm2.qreg(1) + qasm2.id(q[0]) + qasm2.x(q[0]) + qasm2.y(q[0]) + qasm2.z(q[0]) + qasm2.h(q[0]) + qasm2.s(q[0]) + qasm2.t(q[0]) + qasm2.sx(q[0]) + + # 2q gates + qasm2.cx(q[0], q[1]) + qasm2.cy(q[0], q[1]) + qasm2.cz(q[0], q[1]) + + return q + + non_parametric_gates.print() + QASM2ToSquin(dialects=squin.kernel)(non_parametric_gates) + AggressiveUnroll(dialects=squin.kernel).fixpoint(non_parametric_gates) + + actual_stmts = list(non_parametric_gates.callable_region.walk()) + # should be no identity whatsoever + assert not any(isinstance(stmt, qasm2.uop.stmts.Id) for stmt in actual_stmts) + actual_stmts = [ + stmt for stmt in actual_stmts if isinstance(stmt, squin.gate.stmts.Gate) + ] + expected_stmts = [ + squin.gate.stmts.X, + squin.gate.stmts.Y, + squin.gate.stmts.Z, + squin.gate.stmts.H, + squin.gate.stmts.S, + squin.gate.stmts.T, + squin.gate.stmts.SqrtX, + squin.gate.stmts.CX, + squin.gate.stmts.CY, + squin.gate.stmts.CZ, + ] + + assert [type(stmt) for stmt in actual_stmts] == expected_stmts + + +def test_parametric_gates(): + + const_pi = math.pi + + @qasm2.main + def rotation_gates(): + q = qasm2.qreg(3) + half_turn = const_pi + quarter_turn = const_pi / 2 + eighth_turn = const_pi / 4 + qasm2.rz(q[0], half_turn) + qasm2.rx(q[1], half_turn) + qasm2.ry(q[2], half_turn) + + qasm2.u3(q[0], half_turn, quarter_turn, eighth_turn) + qasm2.u2(q[0], quarter_turn, half_turn) + qasm2.u1(q[0], eighth_turn) + return q + + rotation_gates.print() + QASM2ToSquin(dialects=squin.kernel)(rotation_gates) + AggressiveUnroll(dialects=squin.kernel).fixpoint(rotation_gates) + rotation_gates.print() + + actual_stmts = list(rotation_gates.callable_region.walk()) + actual_stmts = [ + stmt for stmt in actual_stmts if isinstance(stmt, squin.gate.stmts.Gate) + ] + + assert len(actual_stmts) == 6 + + assert type(actual_stmts[0]) is squin.gate.stmts.Rz + assert actual_stmts[0].angle.owner.value.data == 0.5 + assert type(actual_stmts[1]) is squin.gate.stmts.Rx + assert actual_stmts[1].angle.owner.value.data == 0.5 + assert type(actual_stmts[2]) is squin.gate.stmts.Ry + assert actual_stmts[2].angle.owner.value.data == 0.5 + + # U3 + assert type(actual_stmts[3]) is squin.gate.stmts.U3 + assert actual_stmts[3].theta.owner.value.data == 0.5 + assert actual_stmts[3].phi.owner.value.data == 0.25 + assert actual_stmts[3].lam.owner.value.data == 0.125 + + # U2 is just U3(pi/2, phi, lam) + assert type(actual_stmts[4]) is squin.gate.stmts.U3 + assert actual_stmts[4].theta.owner.value.data == 0.25 + assert actual_stmts[4].phi.owner.value.data == 0.25 + assert actual_stmts[4].lam.owner.value.data == 0.5 + + assert type(actual_stmts[5]) is squin.gate.stmts.U3 + assert actual_stmts[5].theta.owner.value.data == 0.0 + assert actual_stmts[5].phi.owner.value.data == 0.0 + assert actual_stmts[5].lam.owner.value.data == 0.125 + + +def test_noise(): + + @qasm2.extended + def noise_program(): + q = qasm2.qreg(4) + noise.atom_loss_channel(qargs=q, prob=0.05) + noise.pauli_channel(qargs=q, px=0.01, py=0.02, pz=0.03) + noise.cz_pauli_channel( + ctrls=[q[0], q[1]], + qargs=[q[2], q[3]], + px_ctrl=0.01, + py_ctrl=0.02, + pz_ctrl=0.03, + px_qarg=0.04, + py_qarg=0.05, + pz_qarg=0.06, + paired=True, + ) + return q + + noise_program.print() + QASM2ToSquin(dialects=noise_program.dialects)(noise_program) + AggressiveUnroll(dialects=noise_program.dialects).fixpoint(noise_program) + frame, _ = AddressAnalysis(dialects=noise_program.dialects).run(noise_program) + + actual_stmts = list(noise_program.callable_region.walk()) + actual_stmts = [ + stmt + for stmt in actual_stmts + if isinstance(stmt, squin.noise.stmts.NoiseChannel) + ] + + assert type(actual_stmts[0]) is squin.noise.stmts.QubitLoss + assert actual_stmts[0].p.owner.value.data == 0.05 + + assert type(actual_stmts[1]) is squin.noise.stmts.SingleQubitPauliChannel + assert actual_stmts[1].px.owner.value.data == 0.01 + assert actual_stmts[1].py.owner.value.data == 0.02 + assert actual_stmts[1].pz.owner.value.data == 0.03 + + # originate from the same cz_pauli_channel + assert type(actual_stmts[2]) is squin.noise.stmts.SingleQubitPauliChannel + assert type(actual_stmts[3]) is squin.noise.stmts.SingleQubitPauliChannel + # control qubits + assert frame.get(actual_stmts[2].qubits) == AddressReg(data=(0, 1)) + # target qubits + assert frame.get(actual_stmts[3].qubits) == AddressReg(data=(2, 3)) + # assert probabilities are correct on control + assert actual_stmts[2].px.owner.value.data == 0.01 + assert actual_stmts[2].py.owner.value.data == 0.02 + assert actual_stmts[2].pz.owner.value.data == 0.03 + # assert probabilities are correct on targets + assert actual_stmts[3].px.owner.value.data == 0.04 + assert actual_stmts[3].py.owner.value.data == 0.05 + assert actual_stmts[3].pz.owner.value.data == 0.06 + + +def test_global_and_parallel(): + + const_pi = math.pi + + @qasm2.extended + def global_parallel_program(): + q = qasm2.qreg(6) + half_turn = const_pi + quarter_turn = const_pi / 2 + eighth_turn = const_pi / 4 + + parallel.u([q[0]], half_turn, quarter_turn, eighth_turn) + glob.u([q[2], q[1], q[3]], half_turn, quarter_turn, eighth_turn) + parallel.rz([q[4], q[5]], eighth_turn) + return q + + global_parallel_program.print() + QASM2ToSquin(dialects=global_parallel_program.dialects)(global_parallel_program) + AggressiveUnroll(dialects=global_parallel_program.dialects).fixpoint( + global_parallel_program + ) + + actual_stmts = list(global_parallel_program.callable_region.walk()) + actual_stmts = [ + stmt for stmt in actual_stmts if isinstance(stmt, squin.gate.stmts.Gate) + ] + + assert type(actual_stmts[0]) is squin.gate.stmts.U3 + assert actual_stmts[0].theta.owner.value.data == 0.5 + assert actual_stmts[0].phi.owner.value.data == 0.25 + assert actual_stmts[0].lam.owner.value.data == 0.125 + + assert type(actual_stmts[1]) is squin.gate.stmts.U3 + assert actual_stmts[1].theta.owner.value.data == 0.5 + assert actual_stmts[1].phi.owner.value.data == 0.25 + assert actual_stmts[1].lam.owner.value.data == 0.125 + + assert type(actual_stmts[2]) is squin.gate.stmts.Rz + assert actual_stmts[2].angle.owner.value.data == 0.125 + + +def test_func(): + + @qasm2.main + def sub_kernel(ctrl, target): + qasm2.cx(ctrl, target) + return + + @qasm2.main + def main_kernel(): + q = qasm2.qreg(2) + sub_kernel(q[0], q[1]) + return q + + main_kernel.print() + QASM2ToSquin(dialects=main_kernel.dialects)(main_kernel) + AggressiveUnroll(dialects=main_kernel.dialects).fixpoint(main_kernel) + main_kernel.print() + + +test_func() diff --git a/test/squin/rewrite/test_qasm2_to_squin.py b/test/squin/rewrite/test_qasm2_to_squin.py deleted file mode 100644 index 6cc75817..00000000 --- a/test/squin/rewrite/test_qasm2_to_squin.py +++ /dev/null @@ -1,104 +0,0 @@ -from kirin.rewrite import Walk, Chain -from kirin.passes.inline import InlinePass - -from bloqade import qasm2, squin -from bloqade.qasm2 import glob, noise, parallel - -# from kirin import prelude -# from kirin.passes import Fold, TypeInfer -# from kirin.dialects.ilist.passes import IListDesugar -from bloqade.squin.passes import QASM2ToSquin -from bloqade.squin.rewrite.qasm2 import ( - QASM2CoreToSquin, - QASM2ExprToSquin, -) - - -def test_expr_rewrite(): - - @qasm2.main - def expr_program(): - x = 0 - z = -1.7 - y = x + z - qasm2.sin(y) - return y - - Walk(QASM2ExprToSquin()).rewrite(expr_program.code) - - expr_program.print() - - -def test_qasm2_core(): - - @qasm2.main(fold=False) - def measure_kern(): - q = qasm2.qreg(5) - q0 = q[0] - return q0 - - measure_kern.print() - - Walk(Chain(QASM2ExprToSquin(), QASM2CoreToSquin())).rewrite(measure_kern.code) - - measure_kern.print() - - -def test_gates(): - - @qasm2.main - def gate_program(): - q = qasm2.qreg(10) - qasm2.cx(q[0], q[1]) - qasm2.z(q[3]) - qasm2.u3(q[4], 1.57, 0.0, 3.14) - qasm2.u1(q[5], 0.78) - qasm2.rz(q[6], 2.34) - qasm2.sx(q[7]) - return q - - gate_program.print() - QASM2ToSquin(dialects=squin.kernel)(gate_program) - gate_program.print() - - -def test_noise(): - - @qasm2.extended - def noise_program(): - q = qasm2.qreg(4) - noise.atom_loss_channel(qargs=q, prob=0.05) - noise.pauli_channel(qargs=q, px=0.01, py=0.02, pz=0.03) - noise.cz_pauli_channel( - ctrls=[q[0], q[1]], - qargs=[q[2], q[3]], - px_ctrl=0.01, - py_ctrl=0.02, - pz_ctrl=0.03, - px_qarg=0.04, - py_qarg=0.05, - pz_qarg=0.06, - paired=True, - ) - return q - - noise_program.print() - QASM2ToSquin(dialects=squin.kernel)(noise_program) - InlinePass(dialects=squin.kernel)(noise_program) - - -def test_global_and_parallel(): - - @qasm2.extended - def global_parallel_program(): - q = qasm2.qreg(6) - parallel.u([q[0]], 1.0, 0.0, 3.14) - glob.u([q[1], q[2], q[3]], 0.5, 1.57, 0.0) - parallel.rz([q[4], q[5]], 2.34) - - return q - - global_parallel_program.print() - QASM2ToSquin(dialects=squin.kernel)(global_parallel_program) - InlinePass(dialects=squin.kernel)(global_parallel_program) - global_parallel_program.print() From 1fa9865de6b714b5903020a75cfb088da78a109f Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 4 Dec 2025 09:12:41 -0500 Subject: [PATCH 04/13] complete func conversion - don't know if it's the best way though --- .../squin/rewrite/qasm2/func_to_squin.py | 33 ++++++++++++------- test/squin/passes/test_qasm2_to_squin.py | 13 +++++--- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/src/bloqade/squin/rewrite/qasm2/func_to_squin.py b/src/bloqade/squin/rewrite/qasm2/func_to_squin.py index 55ca9673..cf16a4f6 100644 --- a/src/bloqade/squin/rewrite/qasm2/func_to_squin.py +++ b/src/bloqade/squin/rewrite/qasm2/func_to_squin.py @@ -1,25 +1,36 @@ from kirin import ir +from kirin.rewrite import Walk, Chain from kirin.dialects import func from kirin.rewrite.abc import RewriteRule, RewriteResult +from .. import qasm2 as qasm2_rules + class QASM2FuncToSquin(RewriteRule): def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + rewrite_result = RewriteResult() if not isinstance(node, func.Invoke): - return RewriteResult() + return rewrite_result callee = node.callee - - return self.rewrite_Region(callee.callable_region) - - def rewrite_Region(self, region: ir.Region) -> RewriteResult: - - rewrite_result = RewriteResult() - - for stmt in list(region.walk()): - result = self.rewrite_Statement(stmt) - rewrite_result = rewrite_result.join(result) + region = callee.callable_region + + for stmt in region.walk(): + rewrite_result = ( + Walk( + Chain( + qasm2_rules.QASM2FuncToSquin(), + qasm2_rules.QASM2ExprToSquin(), + qasm2_rules.QASM2CoreToSquin(), + qasm2_rules.QASM2UOPToSquin(), + qasm2_rules.QASM2GlobParallelToSquin(), + qasm2_rules.QASM2NoiseToSquin(), + ) + ) + .rewrite(stmt) + .join(rewrite_result) + ) return rewrite_result diff --git a/test/squin/passes/test_qasm2_to_squin.py b/test/squin/passes/test_qasm2_to_squin.py index 01e7fea6..231ccef6 100644 --- a/test/squin/passes/test_qasm2_to_squin.py +++ b/test/squin/passes/test_qasm2_to_squin.py @@ -321,9 +321,15 @@ def global_parallel_program(): def test_func(): + @qasm2.main + def another_sub_kernel(q): + qasm2.x(q) + return + @qasm2.main def sub_kernel(ctrl, target): qasm2.cx(ctrl, target) + another_sub_kernel(ctrl) return @qasm2.main @@ -332,10 +338,9 @@ def main_kernel(): sub_kernel(q[0], q[1]) return q - main_kernel.print() QASM2ToSquin(dialects=main_kernel.dialects)(main_kernel) AggressiveUnroll(dialects=main_kernel.dialects).fixpoint(main_kernel) - main_kernel.print() - -test_func() + actual_stmts = [type(stmt) for stmt in main_kernel.callable_region.walk()] + assert actual_stmts.count(squin.gate.stmts.CX) == 1 + assert actual_stmts.count(squin.gate.stmts.X) == 1 From 6552782686dc2a53c63c629f81764d8a3358615a Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 4 Dec 2025 09:16:27 -0500 Subject: [PATCH 05/13] dialect handover --- src/bloqade/squin/passes/qasm2_to_squin.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/bloqade/squin/passes/qasm2_to_squin.py b/src/bloqade/squin/passes/qasm2_to_squin.py index fa31e8e3..64c09a3d 100644 --- a/src/bloqade/squin/passes/qasm2_to_squin.py +++ b/src/bloqade/squin/passes/qasm2_to_squin.py @@ -34,17 +34,18 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult: ) ).rewrite(mt.code) + # kernel should be entirely in squin dialect now + mt.dialects = squin.kernel + # the rest is taken from the squin kernel - rewrite_result = Fold(dialects=squin.kernel).fixpoint(mt) + rewrite_result = Fold(dialects=mt.dialects).fixpoint(mt) rewrite_result = ( - TypeInfer(dialects=squin.kernel).unsafe_run(mt).join(rewrite_result) + TypeInfer(dialects=mt.dialects).unsafe_run(mt).join(rewrite_result) ) rewrite_result = ( - IListDesugar(dialects=squin.kernel).unsafe_run(mt).join(rewrite_result) + IListDesugar(dialects=mt.dialects).unsafe_run(mt).join(rewrite_result) ) - TypeInfer(dialects=squin.kernel).unsafe_run(mt).join(rewrite_result) - - mt.dialects = squin.kernel + TypeInfer(dialects=mt.dialects).unsafe_run(mt).join(rewrite_result) return rewrite_result From d4f46bb70093beec30eba305aee56622a7524682 Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 4 Dec 2025 09:53:11 -0500 Subject: [PATCH 06/13] dialect handover --- src/bloqade/squin/passes/qasm2_to_squin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/bloqade/squin/passes/qasm2_to_squin.py b/src/bloqade/squin/passes/qasm2_to_squin.py index 64c09a3d..2922ae29 100644 --- a/src/bloqade/squin/passes/qasm2_to_squin.py +++ b/src/bloqade/squin/passes/qasm2_to_squin.py @@ -38,7 +38,6 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult: mt.dialects = squin.kernel # the rest is taken from the squin kernel - rewrite_result = Fold(dialects=mt.dialects).fixpoint(mt) rewrite_result = ( TypeInfer(dialects=mt.dialects).unsafe_run(mt).join(rewrite_result) From cd53e5dc49deab0273f4347f3012499b6d8f05d8 Mon Sep 17 00:00:00 2001 From: John Long Date: Fri, 5 Dec 2025 09:30:37 -0500 Subject: [PATCH 07/13] use callgraph pass instead of nesting rewrite rule applications --- .../squin/passes/qasm2_gate_func_to_squin.py | 52 +++++++++++++++++++ src/bloqade/squin/passes/qasm2_to_squin.py | 11 +++- src/bloqade/squin/rewrite/qasm2/__init__.py | 1 - .../squin/rewrite/qasm2/func_to_squin.py | 36 ------------- 4 files changed, 61 insertions(+), 39 deletions(-) create mode 100644 src/bloqade/squin/passes/qasm2_gate_func_to_squin.py delete mode 100644 src/bloqade/squin/rewrite/qasm2/func_to_squin.py diff --git a/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py b/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py new file mode 100644 index 00000000..2d6721c3 --- /dev/null +++ b/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py @@ -0,0 +1,52 @@ +from kirin import ir, passes +from kirin.rewrite import Walk, Chain +from kirin.dialects import func +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade.rewrite.passes import CallGraphPass + +from ..rewrite import qasm2 as qasm2_rule + + +class QASM2GateFuncToKirinFunc(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + from bloqade.qasm2.dialects.expr.stmts import GateFunction + + if not isinstance(node, GateFunction): + return RewriteResult() + + kirin_func = func.Function( + sym_name=node.sym_name, + signature=node.signature, + body=node.body, + ) + node.replace_by(kirin_func) + + return RewriteResult(has_done_something=True) + + +class QASM2GateFuncToSquinPass(passes.Pass): + + def unsafe_run(self, mt: ir.Method) -> RewriteResult: + convert_to_kirin_func = CallGraphPass( + dialects=mt.dialects, rule=Walk(QASM2GateFuncToKirinFunc()) + ) + rewrite_result = convert_to_kirin_func(mt) + + combined_qasm2_rules = Walk( + Chain( + qasm2_rule.QASM2ExprToSquin(), + qasm2_rule.QASM2CoreToSquin(), + qasm2_rule.QASM2UOPToSquin(), + qasm2_rule.QASM2NoiseToSquin(), + qasm2_rule.QASM2GlobParallelToSquin(), + ) + ) + + body_conversion_pass = CallGraphPass( + dialects=mt.dialects, rule=combined_qasm2_rules + ) + rewrite_result = body_conversion_pass(mt).join(rewrite_result) + + return rewrite_result diff --git a/src/bloqade/squin/passes/qasm2_to_squin.py b/src/bloqade/squin/passes/qasm2_to_squin.py index 2922ae29..0bf134a4 100644 --- a/src/bloqade/squin/passes/qasm2_to_squin.py +++ b/src/bloqade/squin/passes/qasm2_to_squin.py @@ -11,11 +11,12 @@ QASM2UOPToSquin, QASM2CoreToSquin, QASM2ExprToSquin, - QASM2FuncToSquin, QASM2NoiseToSquin, QASM2GlobParallelToSquin, ) +from .qasm2_gate_func_to_squin import QASM2GateFuncToSquinPass + @dataclass class QASM2ToSquin(Pass): @@ -25,7 +26,6 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult: # rewrite all QASM2 to squin first rewrite_result = Walk( Chain( - QASM2FuncToSquin(), QASM2ExprToSquin(), QASM2CoreToSquin(), QASM2UOPToSquin(), @@ -34,6 +34,13 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult: ) ).rewrite(mt.code) + # go into subkernels + rewrite_result = ( + QASM2GateFuncToSquinPass(dialects=mt.dialects) + .unsafe_run(mt) + .join(rewrite_result) + ) + # kernel should be entirely in squin dialect now mt.dialects = squin.kernel diff --git a/src/bloqade/squin/rewrite/qasm2/__init__.py b/src/bloqade/squin/rewrite/qasm2/__init__.py index 7193afeb..c2355fde 100644 --- a/src/bloqade/squin/rewrite/qasm2/__init__.py +++ b/src/bloqade/squin/rewrite/qasm2/__init__.py @@ -1,6 +1,5 @@ from .uop_to_squin import QASM2UOPToSquin as QASM2UOPToSquin from .core_to_squin import QASM2CoreToSquin as QASM2CoreToSquin from .expr_to_squin import QASM2ExprToSquin as QASM2ExprToSquin -from .func_to_squin import QASM2FuncToSquin as QASM2FuncToSquin from .noise_to_squin import QASM2NoiseToSquin as QASM2NoiseToSquin from .glob_parallel_to_squin import QASM2GlobParallelToSquin as QASM2GlobParallelToSquin diff --git a/src/bloqade/squin/rewrite/qasm2/func_to_squin.py b/src/bloqade/squin/rewrite/qasm2/func_to_squin.py deleted file mode 100644 index cf16a4f6..00000000 --- a/src/bloqade/squin/rewrite/qasm2/func_to_squin.py +++ /dev/null @@ -1,36 +0,0 @@ -from kirin import ir -from kirin.rewrite import Walk, Chain -from kirin.dialects import func -from kirin.rewrite.abc import RewriteRule, RewriteResult - -from .. import qasm2 as qasm2_rules - - -class QASM2FuncToSquin(RewriteRule): - - def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - - rewrite_result = RewriteResult() - if not isinstance(node, func.Invoke): - return rewrite_result - - callee = node.callee - region = callee.callable_region - - for stmt in region.walk(): - rewrite_result = ( - Walk( - Chain( - qasm2_rules.QASM2FuncToSquin(), - qasm2_rules.QASM2ExprToSquin(), - qasm2_rules.QASM2CoreToSquin(), - qasm2_rules.QASM2UOPToSquin(), - qasm2_rules.QASM2GlobParallelToSquin(), - qasm2_rules.QASM2NoiseToSquin(), - ) - ) - .rewrite(stmt) - .join(rewrite_result) - ) - - return rewrite_result From 59dfcb0b486c5dd33c155daabb9d64c53958c3d2 Mon Sep 17 00:00:00 2001 From: John Long Date: Fri, 5 Dec 2025 09:47:49 -0500 Subject: [PATCH 08/13] add missing slots arg --- src/bloqade/squin/passes/qasm2_gate_func_to_squin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py b/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py index 2d6721c3..89d437d7 100644 --- a/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py +++ b/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py @@ -20,6 +20,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: sym_name=node.sym_name, signature=node.signature, body=node.body, + slots=node.slots, ) node.replace_by(kirin_func) From 99e66af93a52a48eb123c4d9ef821078da86d2de Mon Sep 17 00:00:00 2001 From: John Long Date: Fri, 5 Dec 2025 11:09:36 -0500 Subject: [PATCH 09/13] get rid of redundant rewrite of QASM2 const/expr to py const/expr - already exists --- .../squin/passes/qasm2_gate_func_to_squin.py | 3 +- src/bloqade/squin/passes/qasm2_to_squin.py | 7 +- src/bloqade/squin/rewrite/qasm2/__init__.py | 1 - .../squin/rewrite/qasm2/expr_to_squin.py | 113 ------------------ .../measure_id/test_refactored_measure_id.py | 32 +++++ test/squin/passes/test_qasm2_to_squin.py | 73 +---------- 6 files changed, 40 insertions(+), 189 deletions(-) delete mode 100644 src/bloqade/squin/rewrite/qasm2/expr_to_squin.py create mode 100644 test/analysis/measure_id/test_refactored_measure_id.py diff --git a/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py b/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py index 89d437d7..3d63e671 100644 --- a/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py +++ b/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py @@ -4,6 +4,7 @@ from kirin.rewrite.abc import RewriteRule, RewriteResult from bloqade.rewrite.passes import CallGraphPass +from bloqade.qasm2.passes.qasm2py import _QASM2Py as QASM2ToPyRule from ..rewrite import qasm2 as qasm2_rule @@ -37,7 +38,7 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult: combined_qasm2_rules = Walk( Chain( - qasm2_rule.QASM2ExprToSquin(), + QASM2ToPyRule(), qasm2_rule.QASM2CoreToSquin(), qasm2_rule.QASM2UOPToSquin(), qasm2_rule.QASM2NoiseToSquin(), diff --git a/src/bloqade/squin/passes/qasm2_to_squin.py b/src/bloqade/squin/passes/qasm2_to_squin.py index 0bf134a4..4f2d1108 100644 --- a/src/bloqade/squin/passes/qasm2_to_squin.py +++ b/src/bloqade/squin/passes/qasm2_to_squin.py @@ -10,11 +10,14 @@ from bloqade.squin.rewrite.qasm2 import ( QASM2UOPToSquin, QASM2CoreToSquin, - QASM2ExprToSquin, QASM2NoiseToSquin, QASM2GlobParallelToSquin, ) +# There's a QASM2Py pass that only applies an _QASM2Py rewrite rule, +# I just want the rule here. +from bloqade.qasm2.passes.qasm2py import _QASM2Py as QASM2ToPyRule + from .qasm2_gate_func_to_squin import QASM2GateFuncToSquinPass @@ -26,7 +29,7 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult: # rewrite all QASM2 to squin first rewrite_result = Walk( Chain( - QASM2ExprToSquin(), + QASM2ToPyRule(), QASM2CoreToSquin(), QASM2UOPToSquin(), QASM2GlobParallelToSquin(), diff --git a/src/bloqade/squin/rewrite/qasm2/__init__.py b/src/bloqade/squin/rewrite/qasm2/__init__.py index c2355fde..abcece24 100644 --- a/src/bloqade/squin/rewrite/qasm2/__init__.py +++ b/src/bloqade/squin/rewrite/qasm2/__init__.py @@ -1,5 +1,4 @@ from .uop_to_squin import QASM2UOPToSquin as QASM2UOPToSquin from .core_to_squin import QASM2CoreToSquin as QASM2CoreToSquin -from .expr_to_squin import QASM2ExprToSquin as QASM2ExprToSquin from .noise_to_squin import QASM2NoiseToSquin as QASM2NoiseToSquin from .glob_parallel_to_squin import QASM2GlobParallelToSquin as QASM2GlobParallelToSquin diff --git a/src/bloqade/squin/rewrite/qasm2/expr_to_squin.py b/src/bloqade/squin/rewrite/qasm2/expr_to_squin.py deleted file mode 100644 index 5b913562..00000000 --- a/src/bloqade/squin/rewrite/qasm2/expr_to_squin.py +++ /dev/null @@ -1,113 +0,0 @@ -from math import pi - -from kirin import ir -from kirin.dialects import py, math as kirin_math -from kirin.rewrite.abc import RewriteRule, RewriteResult - -from bloqade.qasm2.dialects.expr import stmts as expr_stmts - -qasm2_binops = [] - - -class QASM2ExprToSquin(RewriteRule): - - def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - - match node: - case expr_stmts.ConstInt() | expr_stmts.ConstFloat(): - return self.rewrite_Const(node) - case expr_stmts.ConstPI(): - return self.rewrite_PI(node) - case ( - expr_stmts.Mul() - | expr_stmts.Add() - | expr_stmts.Sub() - | expr_stmts.Div() - | expr_stmts.Pow() - ): - return self.rewrite_BinOp(node) - case ( - expr_stmts.Neg() - | expr_stmts.Sin() - | expr_stmts.Cos() - | expr_stmts.Tan() - | expr_stmts.Exp() - | expr_stmts.Log() - | expr_stmts.Sqrt() - ): - return self.rewrite_UnaryOp(node) - case _: - return RewriteResult() - - def rewrite_Const( - self, stmt: expr_stmts.ConstInt | expr_stmts.ConstFloat - ) -> RewriteResult: - - py_const = py.Constant(value=stmt.value) - stmt.replace_by(py_const) - return RewriteResult(has_done_something=True) - - def rewrite_PI(self, stmt: expr_stmts.ConstPI) -> RewriteResult: - - py_const = py.Constant(value=pi) - stmt.replace_by(py_const) - return RewriteResult(has_done_something=True) - - def rewrite_BinOp(self, stmt: ir.Statement) -> RewriteResult: - - match stmt: - case expr_stmts.Mul(): - op = py.binop.Mult - case expr_stmts.Add(): - op = py.binop.Add - case expr_stmts.Sub(): - op = py.binop.Sub - case expr_stmts.Div(): - op = py.binop.Div - case expr_stmts.Pow(): - op = py.binop.Pow - case _: - return RewriteResult() - - lhs = stmt.lhs - rhs = stmt.rhs - binop_expr = op(lhs=lhs, rhs=rhs) - stmt.replace_by(binop_expr) - return RewriteResult(has_done_something=True) - - def rewrite_UnaryOp( - self, - stmt: ( - expr_stmts.Neg - | expr_stmts.Sin - | expr_stmts.Cos - | expr_stmts.Tan - | expr_stmts.Exp - | expr_stmts.Log - | expr_stmts.Sqrt - ), - ) -> RewriteResult: - - match stmt: - case expr_stmts.Neg(): - op = py.unary.stmts.USub - case expr_stmts.Sin(): - op = kirin_math.stmts.sin - case expr_stmts.Cos(): - op = kirin_math.stmts.cos - case expr_stmts.Tan(): - op = kirin_math.stmts.tan - case expr_stmts.Exp(): - op = kirin_math.stmts.exp - case expr_stmts.Log(): - op = kirin_math.stmts.log2 - case expr_stmts.Sqrt(): - op = kirin_math.stmts.sqrt - case _: - return RewriteResult() - - value = stmt.value - unary_expr = op(value) - - stmt.replace_by(unary_expr) - return RewriteResult(has_done_something=True) diff --git a/test/analysis/measure_id/test_refactored_measure_id.py b/test/analysis/measure_id/test_refactored_measure_id.py new file mode 100644 index 00000000..107ed1cf --- /dev/null +++ b/test/analysis/measure_id/test_refactored_measure_id.py @@ -0,0 +1,32 @@ +import io + +from kirin import ir + +from bloqade import stim, squin +from bloqade.stim.emit import EmitStimMain +from bloqade.stim.passes import SquinToStimPass + + +def codegen(mt: ir.Method): + # method should not have any arguments! + buf = io.StringIO() + emit = EmitStimMain(dialects=stim.main, io=buf) + emit.initialize() + emit.run(mt) + return buf.getvalue().strip() + + +@squin.kernel +def test_simple_linear(): + + qs = squin.qalloc(4) + m0 = squin.broadcast.measure(qs) + squin.set_detector([m0[0], m0[1]], coordinates=[0, 0]) + m1 = squin.broadcast.measure(qs) + squin.set_detector([m1[0], m1[1]], coordinates=[1, 1]) + + +test_simple_linear.print() +SquinToStimPass(dialects=test_simple_linear.dialects)(test_simple_linear) +test_simple_linear.print() +print(codegen(test_simple_linear)) diff --git a/test/squin/passes/test_qasm2_to_squin.py b/test/squin/passes/test_qasm2_to_squin.py index 231ccef6..74ecb035 100644 --- a/test/squin/passes/test_qasm2_to_squin.py +++ b/test/squin/passes/test_qasm2_to_squin.py @@ -1,8 +1,7 @@ import math from kirin import types as kirin_types -from kirin.rewrite import Walk -from kirin.dialects import py, func, math as kirin_math +from kirin.dialects import py from kirin.dialects.ilist import IListType from bloqade import qasm2, squin @@ -11,79 +10,9 @@ from bloqade.squin.passes import QASM2ToSquin from bloqade.rewrite.passes import AggressiveUnroll from bloqade.analysis.address import AddressAnalysis -from bloqade.squin.rewrite.qasm2 import ( - QASM2ExprToSquin, -) from bloqade.analysis.address.lattice import AddressReg -def test_expr_rewrite(): - - @qasm2.main - def expr_program(): - # constants - x = 0 # noqa: F841 - # ConstPI only added from lowering - y = qasm2.dialects.expr.stmts.ConstPI() # noqa: F841 - z = -1.75 # noqa: F841 - # binary ops - a = 1 + 1 # noqa: F841 - b = 2 * 2 # noqa: F841 - c = 3 - 3 # noqa: F841 - d = 4 / 4 # noqa: F841 - e = 5**2 # noqa: F841 - - # math - a = 0.2 - qasm2.sin(a) - qasm2.cos(a) - qasm2.tan(a) - qasm2.exp(a) - qasm2.ln(a) - qasm2.sqrt(a) - return - - expr_program.print() - - Walk(QASM2ExprToSquin()).rewrite(expr_program.code) - - expr_program.print() - - actual_stmt_sequence = list(expr_program.callable_region.walk()) - - def is_pi_const(stmt: py.Constant): - return isinstance(stmt, py.Constant) and math.isclose(stmt.value.data, math.pi) - - assert any(is_pi_const(stmt) for stmt in actual_stmt_sequence) - - assert qasm2.expr.ConstFloat not in actual_stmt_sequence - assert qasm2.expr.ConstInt not in actual_stmt_sequence - assert qasm2.expr.ConstPI not in actual_stmt_sequence - - no_const_actual_sequence = [ - type(stmt) - for stmt in actual_stmt_sequence - if not isinstance(stmt, (py.Constant, func.ConstantNone, func.Return)) - ] - - expected_stmt_sequence = [ - py.unary.stmts.USub, - py.binop.Add, - py.binop.Mult, - py.binop.Sub, - py.binop.Div, - py.binop.Pow, - kirin_math.stmts.sin, - kirin_math.stmts.cos, - kirin_math.stmts.tan, - kirin_math.stmts.exp, - kirin_math.stmts.log2, - kirin_math.stmts.sqrt, - ] - - assert no_const_actual_sequence == expected_stmt_sequence - - def test_qasm2_core(): @qasm2.main(fold=False) From 045fa00ad1c4d6996db7139e105d56d628daf757 Mon Sep 17 00:00:00 2001 From: John Long Date: Fri, 5 Dec 2025 15:48:18 -0500 Subject: [PATCH 10/13] broke everything up into smaller rulesgst --- src/bloqade/qasm2/dialects/uop/stmts.py | 2 +- .../squin/passes/qasm2_gate_func_to_squin.py | 7 +- src/bloqade/squin/passes/qasm2_to_squin.py | 10 +- src/bloqade/squin/rewrite/qasm2/__init__.py | 7 +- .../squin/rewrite/qasm2/core_to_squin.py | 43 +++-- .../rewrite/qasm2/glob_parallel_to_squin.py | 59 +++---- .../squin/rewrite/qasm2/id_to_squin.py | 15 ++ .../squin/rewrite/qasm2/noise_to_squin.py | 87 ++++----- .../qasm2/parametrized_uop_1q_to_squin.py | 46 +++++ .../squin/rewrite/qasm2/uop_1q_to_squin.py | 33 ++++ .../squin/rewrite/qasm2/uop_2q_to_squin.py | 28 +++ .../squin/rewrite/qasm2/uop_to_squin.py | 166 ------------------ src/bloqade/squin/rewrite/qasm2/util.py | 15 ++ test/squin/passes/test_qasm2_to_squin.py | 68 +++++-- 14 files changed, 289 insertions(+), 297 deletions(-) create mode 100644 src/bloqade/squin/rewrite/qasm2/id_to_squin.py create mode 100644 src/bloqade/squin/rewrite/qasm2/parametrized_uop_1q_to_squin.py create mode 100644 src/bloqade/squin/rewrite/qasm2/uop_1q_to_squin.py create mode 100644 src/bloqade/squin/rewrite/qasm2/uop_2q_to_squin.py delete mode 100644 src/bloqade/squin/rewrite/qasm2/uop_to_squin.py create mode 100644 src/bloqade/squin/rewrite/qasm2/util.py diff --git a/src/bloqade/qasm2/dialects/uop/stmts.py b/src/bloqade/qasm2/dialects/uop/stmts.py index 755b50e5..55f55610 100644 --- a/src/bloqade/qasm2/dialects/uop/stmts.py +++ b/src/bloqade/qasm2/dialects/uop/stmts.py @@ -28,7 +28,7 @@ class TwoQubitCtrlGate(ir.Statement): @statement(dialect=dialect) class CX(TwoQubitCtrlGate): - """Alias for the CNOT or CH gate operations.""" + """Alias for the CNOT or CX gate operations.""" name = "CX" # Note this is capitalized diff --git a/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py b/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py index 3d63e671..613616d2 100644 --- a/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py +++ b/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py @@ -40,9 +40,12 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult: Chain( QASM2ToPyRule(), qasm2_rule.QASM2CoreToSquin(), - qasm2_rule.QASM2UOPToSquin(), - qasm2_rule.QASM2NoiseToSquin(), qasm2_rule.QASM2GlobParallelToSquin(), + qasm2_rule.QASM2NoiseToSquin(), + qasm2_rule.QASM2IdToSquin(), + qasm2_rule.QASM2UOp1QToSquin(), + qasm2_rule.QASM2ParametrizedUOp1QToSquin(), + qasm2_rule.QASM2UOp2QToSquin(), ) ) diff --git a/src/bloqade/squin/passes/qasm2_to_squin.py b/src/bloqade/squin/passes/qasm2_to_squin.py index 4f2d1108..d23605d7 100644 --- a/src/bloqade/squin/passes/qasm2_to_squin.py +++ b/src/bloqade/squin/passes/qasm2_to_squin.py @@ -8,10 +8,13 @@ from bloqade import squin from bloqade.squin.rewrite.qasm2 import ( - QASM2UOPToSquin, + QASM2IdToSquin, QASM2CoreToSquin, QASM2NoiseToSquin, + QASM2UOp1QToSquin, + QASM2UOp2QToSquin, QASM2GlobParallelToSquin, + QASM2ParametrizedUOp1QToSquin, ) # There's a QASM2Py pass that only applies an _QASM2Py rewrite rule, @@ -31,9 +34,12 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult: Chain( QASM2ToPyRule(), QASM2CoreToSquin(), - QASM2UOPToSquin(), QASM2GlobParallelToSquin(), QASM2NoiseToSquin(), + QASM2IdToSquin(), + QASM2UOp1QToSquin(), + QASM2ParametrizedUOp1QToSquin(), + QASM2UOp2QToSquin(), ) ).rewrite(mt.code) diff --git a/src/bloqade/squin/rewrite/qasm2/__init__.py b/src/bloqade/squin/rewrite/qasm2/__init__.py index abcece24..0cd85c7c 100644 --- a/src/bloqade/squin/rewrite/qasm2/__init__.py +++ b/src/bloqade/squin/rewrite/qasm2/__init__.py @@ -1,4 +1,9 @@ -from .uop_to_squin import QASM2UOPToSquin as QASM2UOPToSquin +from .id_to_squin import QASM2IdToSquin as QASM2IdToSquin from .core_to_squin import QASM2CoreToSquin as QASM2CoreToSquin from .noise_to_squin import QASM2NoiseToSquin as QASM2NoiseToSquin +from .uop_1q_to_squin import QASM2UOp1QToSquin as QASM2UOp1QToSquin +from .uop_2q_to_squin import QASM2UOp2QToSquin as QASM2UOp2QToSquin from .glob_parallel_to_squin import QASM2GlobParallelToSquin as QASM2GlobParallelToSquin +from .parametrized_uop_1q_to_squin import ( + QASM2ParametrizedUOp1QToSquin as QASM2ParametrizedUOp1QToSquin, +) diff --git a/src/bloqade/squin/rewrite/qasm2/core_to_squin.py b/src/bloqade/squin/rewrite/qasm2/core_to_squin.py index 27ea5435..29de3f48 100644 --- a/src/bloqade/squin/rewrite/qasm2/core_to_squin.py +++ b/src/bloqade/squin/rewrite/qasm2/core_to_squin.py @@ -5,29 +5,34 @@ from bloqade import squin from bloqade.qasm2.dialects.core import stmts as core_stmts +CORE_TO_SQUIN_MAP = { + core_stmts.QRegNew: squin.qubit.qalloc, + core_stmts.Reset: squin.qubit.reset, +} + class QASM2CoreToSquin(RewriteRule): def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - match node: - case core_stmts.QRegNew(n_qubits=n_qubits): - qalloc_invoke_stmt = func.Invoke( - callee=squin.qubit.qalloc, inputs=(n_qubits,) - ) - node.replace_by(qalloc_invoke_stmt) - case core_stmts.Reset(qarg=qarg): - reset_invoke_stmt = func.Invoke( - callee=squin.qubit.reset, inputs=(qarg,) - ) - node.replace_by(reset_invoke_stmt) - case core_stmts.QRegGet(reg=reg, idx=idx): - get_item_stmt = py.GetItem( - obj=reg, - index=idx, - ) - node.replace_by(get_item_stmt) - case _: - return RewriteResult() + if isinstance(node, core_stmts.QRegGet): + py_get_item = py.GetItem( + obj=node.reg, + index=node.idx, + ) + node.replace_by(py_get_item) + return RewriteResult(has_done_something=True) + + if isinstance(node, core_stmts.QRegNew): + args = (node.n_qubits,) + elif isinstance(node, core_stmts.Reset): + args = (node.qarg,) + else: + return RewriteResult() + new_stmt = func.Invoke( + callee=CORE_TO_SQUIN_MAP[type(node)], + inputs=args, + ) + node.replace_by(new_stmt) return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/qasm2/glob_parallel_to_squin.py b/src/bloqade/squin/rewrite/qasm2/glob_parallel_to_squin.py index 65322d00..9fe52471 100644 --- a/src/bloqade/squin/rewrite/qasm2/glob_parallel_to_squin.py +++ b/src/bloqade/squin/rewrite/qasm2/glob_parallel_to_squin.py @@ -5,47 +5,30 @@ from bloqade import squin from bloqade.qasm2.dialects import glob, parallel +GLOBAL_PARALLEL_TO_SQUIN_MAP = { + glob.UGate: squin.broadcast.u3, + parallel.UGate: squin.broadcast.u3, + parallel.RZ: squin.broadcast.rz, +} + class QASM2GlobParallelToSquin(RewriteRule): def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - match node: - case glob.UGate() | parallel.UGate() | parallel.RZ(): - return self.rewrite_1q_gates(node) - case _: - return RewriteResult() - - return RewriteResult(has_done_something=True) - - def rewrite_1q_gates( - self, stmt: glob.UGate | parallel.UGate | parallel.RZ - ) -> RewriteResult: - - match stmt: - case glob.UGate(theta=theta, phi=phi, lam=lam) | parallel.UGate( - theta=theta, phi=phi, lam=lam - ): - # ever so slight naming difference, - # exists because intended semantics are different - match stmt: - case glob.UGate(): - qargs = stmt.registers - case parallel.UGate(): - qargs = stmt.qargs - - invoke_u_broadcast_stmt = func.Invoke( - callee=squin.broadcast.u3, - inputs=(theta, phi, lam, qargs), - ) - stmt.replace_by(invoke_u_broadcast_stmt) - case parallel.RZ(theta=theta, qargs=qargs): - invoke_rz_broadcast_stmt = func.Invoke( - callee=squin.broadcast.rz, - inputs=(theta, qargs), - ) - stmt.replace_by(invoke_rz_broadcast_stmt) - case _: - return RewriteResult() - + if isinstance(node, glob.UGate): + args = (node.theta, node.phi, node.lam, node.registers) + elif isinstance(node, parallel.UGate): + args = (node.theta, node.phi, node.lam, node.qargs) + elif isinstance(node, parallel.RZ): + args = (node.theta, node.qargs) + else: + return RewriteResult() + + squin_equivalent_stmt = GLOBAL_PARALLEL_TO_SQUIN_MAP[type(node)] + invoke_stmt = func.Invoke( + callee=squin_equivalent_stmt, + inputs=args, + ) + node.replace_by(invoke_stmt) return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/qasm2/id_to_squin.py b/src/bloqade/squin/rewrite/qasm2/id_to_squin.py new file mode 100644 index 00000000..b5f12f57 --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/id_to_squin.py @@ -0,0 +1,15 @@ +from kirin import ir +from kirin.rewrite.abc import RewriteRule, RewriteResult + +import bloqade.qasm2.dialects.uop.stmts as uop_stmts + + +class QASM2IdToSquin(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + if not isinstance(node, uop_stmts.Id): + return RewriteResult() + + node.delete() + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/qasm2/noise_to_squin.py b/src/bloqade/squin/rewrite/qasm2/noise_to_squin.py index 625d9a29..90b03553 100644 --- a/src/bloqade/squin/rewrite/qasm2/noise_to_squin.py +++ b/src/bloqade/squin/rewrite/qasm2/noise_to_squin.py @@ -1,67 +1,43 @@ from kirin import ir -from kirin.dialects import py, func +from kirin.dialects import func from kirin.rewrite.abc import RewriteRule, RewriteResult from bloqade import squin from bloqade.qasm2.dialects.noise import stmts as noise_stmts +from .util import num_to_py_constant -class QASM2NoiseToSquin(RewriteRule): - - def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - - match node: - case noise_stmts.AtomLossChannel(): - return self.rewrite_AtomLossChannel(node) - case noise_stmts.PauliChannel(): - return self.rewrite_PauliChannel(node) - case noise_stmts.CZPauliChannel(): - return self.rewrite_CZPauliChannel(node) - case _: - return RewriteResult() - - return RewriteResult() +NOISE_TO_SQUIN_MAP = { + noise_stmts.AtomLossChannel: squin.broadcast.qubit_loss, + noise_stmts.PauliChannel: squin.broadcast.single_qubit_pauli_channel, +} - def rewrite_AtomLossChannel( - self, stmt: noise_stmts.AtomLossChannel - ) -> RewriteResult: - qargs = stmt.qargs - # this is a raw float, not in SSA form yet! - prob = stmt.prob - prob_stmt = py.Constant(value=prob) - prob_stmt.insert_before(stmt) - - invoke_loss_stmt = func.Invoke( - callee=squin.broadcast.qubit_loss, - inputs=(prob_stmt.result, qargs), - ) - - stmt.replace_by(invoke_loss_stmt) - - return RewriteResult(has_done_something=True) - - def rewrite_PauliChannel(self, stmt: noise_stmts.PauliChannel) -> RewriteResult: - - qargs = stmt.qargs - p_x = stmt.px - p_y = stmt.py - p_z = stmt.pz - - probs = [p_x, p_y, p_z] - probs_ssas = [] +class QASM2NoiseToSquin(RewriteRule): - for prob in probs: - prob_stmt = py.Constant(value=prob) - prob_stmt.insert_before(stmt) - probs_ssas.append(prob_stmt.result) + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - invoke_pauli_channel_stmt = func.Invoke( - callee=squin.broadcast.single_qubit_pauli_channel, - inputs=(*probs_ssas, qargs), + if isinstance(node, noise_stmts.AtomLossChannel): + qargs = node.qargs + prob = node.prob + prob_ssas = num_to_py_constant([prob], stmt_to_insert_before=node) + elif isinstance(node, noise_stmts.PauliChannel): + qargs = node.qargs + p_x = node.px + p_y = node.py + p_z = node.pz + prob_ssas = num_to_py_constant([p_x, p_y, p_z], stmt_to_insert_before=node) + elif isinstance(node, noise_stmts.CZPauliChannel): + return self.rewrite_CZPauliChannel(node) + else: + return RewriteResult() + + squin_noise_stmt = NOISE_TO_SQUIN_MAP[type(node)] + invoke_stmt = func.Invoke( + callee=squin_noise_stmt, + inputs=(*prob_ssas, qargs), ) - - stmt.replace_by(invoke_pauli_channel_stmt) + node.replace_by(invoke_stmt) return RewriteResult(has_done_something=True) def rewrite_CZPauliChannel(self, stmt: noise_stmts.CZPauliChannel) -> RewriteResult: @@ -78,11 +54,8 @@ def rewrite_CZPauliChannel(self, stmt: noise_stmts.CZPauliChannel) -> RewriteRes error_probs = [px_ctrl, py_ctrl, pz_ctrl, px_qarg, py_qarg, pz_qarg] # first half of entries for control qubits, other half for targets - error_prob_ssas = [] - for error_prob in error_probs: - error_prob_stmt = py.Constant(value=error_prob) - error_prob_stmt.insert_before(stmt) - error_prob_ssas.append(error_prob_stmt.result) + + error_prob_ssas = num_to_py_constant(error_probs, stmt_to_insert_before=stmt) ctrl_pauli_channel_invoke = func.Invoke( callee=squin.broadcast.single_qubit_pauli_channel, diff --git a/src/bloqade/squin/rewrite/qasm2/parametrized_uop_1q_to_squin.py b/src/bloqade/squin/rewrite/qasm2/parametrized_uop_1q_to_squin.py new file mode 100644 index 00000000..ffca71e0 --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/parametrized_uop_1q_to_squin.py @@ -0,0 +1,46 @@ +from math import pi + +from kirin import ir +from kirin.dialects import py, func +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import squin +from bloqade.qasm2.dialects.uop import stmts as uop_stmts + +PARAMETRIZED_1Q_GATES_TO_SQUIN_MAP = { + uop_stmts.UGate: squin.u3, + uop_stmts.U1: squin.u3, + uop_stmts.U2: squin.u3, + uop_stmts.RZ: squin.rz, + uop_stmts.RX: squin.rx, + uop_stmts.RY: squin.ry, +} + + +class QASM2ParametrizedUOp1QToSquin(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + if isinstance(node, (uop_stmts.RX, uop_stmts.RY, uop_stmts.RZ)): + args = (node.theta, node.qarg) + elif isinstance(node, (uop_stmts.UGate)): + args = (node.theta, node.phi, node.lam, node.qarg) + elif isinstance(node, (uop_stmts.U1)): + zero_stmt = py.Constant(value=0.0) + zero_stmt.insert_before(node) + args = (zero_stmt.result, zero_stmt.result, node.lam, node.qarg) + elif isinstance(node, (uop_stmts.U2)): + half_pi_stmt = py.Constant(value=pi / 2) + half_pi_stmt.insert_before(node) + args = (half_pi_stmt.result, node.phi, node.lam, node.qarg) + else: + return RewriteResult() + + squin_equivalent_stmt = PARAMETRIZED_1Q_GATES_TO_SQUIN_MAP[type(node)] + invoke_stmt = func.Invoke( + callee=squin_equivalent_stmt, + inputs=args, + ) + node.replace_by(invoke_stmt) + + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/qasm2/uop_1q_to_squin.py b/src/bloqade/squin/rewrite/qasm2/uop_1q_to_squin.py new file mode 100644 index 00000000..06e179d0 --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/uop_1q_to_squin.py @@ -0,0 +1,33 @@ +from kirin import ir +from kirin.dialects import func +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import squin +from bloqade.qasm2.dialects.uop import stmts as uop_stmts + +ONE_Q_GATES_TO_SQUIN_MAP = { + uop_stmts.X: squin.x, + uop_stmts.Y: squin.y, + uop_stmts.Z: squin.z, + uop_stmts.H: squin.h, + uop_stmts.S: squin.s, + uop_stmts.T: squin.t, + uop_stmts.SX: squin.sqrt_x, +} + + +class QASM2UOp1QToSquin(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + squin_1q_gate = ONE_Q_GATES_TO_SQUIN_MAP.get(type(node)) + if squin_1q_gate is None: + return RewriteResult() + + invoke_stmt = func.Invoke( + callee=squin_1q_gate, + inputs=(node.qarg,), + ) + node.replace_by(invoke_stmt) + + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/qasm2/uop_2q_to_squin.py b/src/bloqade/squin/rewrite/qasm2/uop_2q_to_squin.py new file mode 100644 index 00000000..5aeeb793 --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/uop_2q_to_squin.py @@ -0,0 +1,28 @@ +from kirin import ir +from kirin.dialects import func +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import squin +from bloqade.qasm2.dialects.uop import stmts as uop_stmts + +CONTROLLED_GATES_TO_SQUIN_MAP = { + uop_stmts.CX: squin.cx, + uop_stmts.CZ: squin.cz, + uop_stmts.CY: squin.cy, +} + + +class QASM2UOp2QToSquin(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + squin_controlled_gate = CONTROLLED_GATES_TO_SQUIN_MAP.get(type(node)) + if squin_controlled_gate is None: + return RewriteResult() + + invoke_stmt = func.Invoke( + callee=squin_controlled_gate, + inputs=(node.ctrl, node.qarg), + ) + node.replace_by(invoke_stmt) + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/qasm2/uop_to_squin.py b/src/bloqade/squin/rewrite/qasm2/uop_to_squin.py deleted file mode 100644 index 7a905d14..00000000 --- a/src/bloqade/squin/rewrite/qasm2/uop_to_squin.py +++ /dev/null @@ -1,166 +0,0 @@ -from math import pi - -from kirin import ir -from kirin.dialects import py, func -from kirin.rewrite.abc import RewriteRule, RewriteResult - -from bloqade import squin -from bloqade.qasm2.dialects.uop import stmts as uop_stmts - - -# assume that qasm2.core conversion has already run beforehand -class QASM2UOPToSquin(RewriteRule): - - def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - - match node: - case uop_stmts.CX() | uop_stmts.CZ() | uop_stmts.CY(): - return self.rewrite_TwoQubitCtrlGate(node) - case ( - uop_stmts.X() - | uop_stmts.Y() - | uop_stmts.Z() - | uop_stmts.H() - | uop_stmts.S() - | uop_stmts.T() - | uop_stmts.SX() - ): - return self.rewrite_SingleQubitGate_no_parameters(node) - case uop_stmts.RZ() | uop_stmts.RX() | uop_stmts.RY(): - return self.rewrite_SingleQubit_with_parameters(node) - case uop_stmts.UGate() | uop_stmts.U1() | uop_stmts.U2(): - return self.rewrite_u_gates(node) - case uop_stmts.Id(): - return self.rewrite_Id(node) - case _: - return RewriteResult() - - def rewrite_TwoQubitCtrlGate( - self, stmt: uop_stmts.CX | uop_stmts.CZ | uop_stmts.CY - ) -> RewriteResult: - - # qasm2 does not have broadcast semantics - # don't have to worry about these being lists - carg = stmt.ctrl - qarg = stmt.qarg - - match stmt: - case uop_stmts.CX(): - squin_2q_stdlib = squin.cx - case uop_stmts.CZ(): - squin_2q_stdlib = squin.cz - case uop_stmts.CY(): - squin_2q_stdlib = squin.cy - case _: - return RewriteResult() - - invoke_stmt = func.Invoke( - callee=squin_2q_stdlib, - inputs=(carg, qarg), - ) - - stmt.replace_by(invoke_stmt) - - return RewriteResult(has_done_something=True) - - def rewrite_SingleQubitGate_no_parameters( - self, - stmt: ( - uop_stmts.X - | uop_stmts.Y - | uop_stmts.Z - | uop_stmts.H - | uop_stmts.S - | uop_stmts.T - ), - ) -> RewriteResult: - - qarg = stmt.qarg - match stmt: - case uop_stmts.X(): - squin_1q_stdlib = squin.x - case uop_stmts.Y(): - squin_1q_stdlib = squin.y - case uop_stmts.Z(): - squin_1q_stdlib = squin.z - case uop_stmts.H(): - squin_1q_stdlib = squin.h - case uop_stmts.S(): - squin_1q_stdlib = squin.s - case uop_stmts.T(): - squin_1q_stdlib = squin.t - case uop_stmts.SX(): - squin_1q_stdlib = squin.sqrt_x - case _: - return RewriteResult() - - invoke_stmt = func.Invoke( - callee=squin_1q_stdlib, - inputs=(qarg,), - ) - - stmt.replace_by(invoke_stmt) - - return RewriteResult(has_done_something=True) - - def rewrite_u_gates( - self, stmt: uop_stmts.UGate | uop_stmts.U1 | uop_stmts.U2 - ) -> RewriteResult: - - match stmt: - case uop_stmts.UGate(lam=lam, phi=phi, theta=theta, qarg=qarg): - args = (theta, phi, lam, qarg) - case uop_stmts.U1(lam=lam, qarg=qarg): - zero_stmt = py.Constant(value=0.0) - zero_stmt.insert_before(stmt) - args = (zero_stmt.result, zero_stmt.result, lam, qarg) - case uop_stmts.U2(phi=phi, lam=lam, qarg=qarg): - half_pi_stmt = py.Constant(value=pi / 2) - half_pi_stmt.insert_before(stmt) - args = (half_pi_stmt.result, phi, lam, qarg) - case _: - return RewriteResult() - - invoke_stmt = func.Invoke( - callee=squin.u3, - inputs=args, - ) - - stmt.replace_by(invoke_stmt) - - return RewriteResult(has_done_something=True) - - def rewrite_SingleQubit_with_parameters( - self, stmt: uop_stmts.RZ | uop_stmts.RX | uop_stmts.RY - ) -> RewriteResult: - - qarg = stmt.qarg - theta = stmt.theta - - match stmt: - case uop_stmts.RZ(): - squin_1q_stdlib = squin.rz - case uop_stmts.RX(): - squin_1q_stdlib = squin.rx - case uop_stmts.RY(): - squin_1q_stdlib = squin.ry - case _: - return RewriteResult() - - invoke_stmt = func.Invoke( - callee=squin_1q_stdlib, - inputs=(theta, qarg), - ) - - stmt.replace_by(invoke_stmt) - - return RewriteResult(has_done_something=True) - - def rewrite_Id(self, stmt: uop_stmts.Id) -> RewriteResult: - - # Identity does not exist in squin, - # we can just remove it from the program - - stmt.delete() - - return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/qasm2/util.py b/src/bloqade/squin/rewrite/qasm2/util.py new file mode 100644 index 00000000..a74528fd --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/util.py @@ -0,0 +1,15 @@ +from kirin import ir +from kirin.dialects import py + + +def num_to_py_constant( + values: list[int | float], stmt_to_insert_before: ir.Statement +) -> list[ir.SSAValue]: + + py_const_ssa_vals = [] + for value in values: + const_form = py.Constant(value=value) + const_form.insert_before(stmt_to_insert_before) + py_const_ssa_vals.append(const_form.result) + + return py_const_ssa_vals diff --git a/test/squin/passes/test_qasm2_to_squin.py b/test/squin/passes/test_qasm2_to_squin.py index 74ecb035..5d05cf32 100644 --- a/test/squin/passes/test_qasm2_to_squin.py +++ b/test/squin/passes/test_qasm2_to_squin.py @@ -1,28 +1,41 @@ import math from kirin import types as kirin_types +from kirin.passes import TypeInfer +from kirin.rewrite import Walk, Chain from kirin.dialects import py from kirin.dialects.ilist import IListType +import bloqade.squin.rewrite.qasm2 as qasm2_rules from bloqade import qasm2, squin from bloqade.qasm2 import glob, noise, parallel from bloqade.types import QubitType -from bloqade.squin.passes import QASM2ToSquin from bloqade.rewrite.passes import AggressiveUnroll from bloqade.analysis.address import AddressAnalysis +from bloqade.qasm2.passes.qasm2py import _QASM2Py as QASM2ToPyRule from bloqade.analysis.address.lattice import AddressReg +from bloqade.squin.passes.qasm2_gate_func_to_squin import QASM2GateFuncToSquinPass def test_qasm2_core(): - @qasm2.main(fold=False) + @qasm2.main def core_kernel(): q = qasm2.qreg(5) q0 = q[0] qasm2.reset(q0) return q0 - QASM2ToSquin(dialects=squin.kernel)(core_kernel) + Walk( + Chain( + QASM2ToPyRule(), + qasm2_rules.QASM2CoreToSquin(), + ) + ).rewrite(core_kernel.code) + + TypeInfer(dialects=squin.kernel)(core_kernel) + + core_kernel.print() stmts = list(core_kernel.callable_region.walk()) get_item_stmt = [stmt for stmt in stmts if isinstance(stmt, py.GetItem)] @@ -69,7 +82,15 @@ def non_parametric_gates(): return q non_parametric_gates.print() - QASM2ToSquin(dialects=squin.kernel)(non_parametric_gates) + Walk( + Chain( + QASM2ToPyRule(), + qasm2_rules.QASM2CoreToSquin(), + qasm2_rules.QASM2IdToSquin(), + qasm2_rules.QASM2UOp1QToSquin(), + qasm2_rules.QASM2UOp2QToSquin(), + ) + ).rewrite(non_parametric_gates.code) AggressiveUnroll(dialects=squin.kernel).fixpoint(non_parametric_gates) actual_stmts = list(non_parametric_gates.callable_region.walk()) @@ -114,9 +135,15 @@ def rotation_gates(): return q rotation_gates.print() - QASM2ToSquin(dialects=squin.kernel)(rotation_gates) + # QASM2ToSquin(dialects=squin.kernel)(rotation_gates) + Walk( + Chain( + QASM2ToPyRule(), + qasm2_rules.QASM2CoreToSquin(), + qasm2_rules.QASM2ParametrizedUOp1QToSquin(), + ) + ).rewrite(rotation_gates.code) AggressiveUnroll(dialects=squin.kernel).fixpoint(rotation_gates) - rotation_gates.print() actual_stmts = list(rotation_gates.callable_region.walk()) actual_stmts = [ @@ -171,9 +198,15 @@ def noise_program(): return q noise_program.print() - QASM2ToSquin(dialects=noise_program.dialects)(noise_program) - AggressiveUnroll(dialects=noise_program.dialects).fixpoint(noise_program) - frame, _ = AddressAnalysis(dialects=noise_program.dialects).run(noise_program) + Walk( + Chain( + qasm2_rules.QASM2CoreToSquin(), + qasm2_rules.QASM2NoiseToSquin(), + ) + ).rewrite(noise_program.code) + AggressiveUnroll(dialects=squin.kernel).fixpoint(noise_program) + noise_program.print() + frame, _ = AddressAnalysis(dialects=squin.kernel).run(noise_program) actual_stmts = list(noise_program.callable_region.walk()) actual_stmts = [ @@ -224,7 +257,12 @@ def global_parallel_program(): return q global_parallel_program.print() - QASM2ToSquin(dialects=global_parallel_program.dialects)(global_parallel_program) + Walk( + Chain( + QASM2ToPyRule(), + qasm2_rules.QASM2GlobParallelToSquin(), + ) + ).rewrite(global_parallel_program.code) AggressiveUnroll(dialects=global_parallel_program.dialects).fixpoint( global_parallel_program ) @@ -267,7 +305,15 @@ def main_kernel(): sub_kernel(q[0], q[1]) return q - QASM2ToSquin(dialects=main_kernel.dialects)(main_kernel) + Walk( + Chain( + QASM2ToPyRule(), + qasm2_rules.QASM2CoreToSquin(), + ) + ).rewrite(main_kernel.code) + + QASM2GateFuncToSquinPass(dialects=main_kernel.dialects).unsafe_run(main_kernel) + AggressiveUnroll(dialects=main_kernel.dialects).fixpoint(main_kernel) actual_stmts = [type(stmt) for stmt in main_kernel.callable_region.walk()] From 5137de0a219e3f84e0642dbcc0efd795f2d3008b Mon Sep 17 00:00:00 2001 From: John Long Date: Sun, 7 Dec 2025 12:21:56 -0500 Subject: [PATCH 11/13] add in QASM2 adjoint to squin adjoint equivalents I missed the first round --- .../squin/rewrite/qasm2/uop_1q_to_squin.py | 5 +++- test/squin/passes/test_qasm2_to_squin.py | 23 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/bloqade/squin/rewrite/qasm2/uop_1q_to_squin.py b/src/bloqade/squin/rewrite/qasm2/uop_1q_to_squin.py index 06e179d0..c9bfb34e 100644 --- a/src/bloqade/squin/rewrite/qasm2/uop_1q_to_squin.py +++ b/src/bloqade/squin/rewrite/qasm2/uop_1q_to_squin.py @@ -11,8 +11,11 @@ uop_stmts.Z: squin.z, uop_stmts.H: squin.h, uop_stmts.S: squin.s, - uop_stmts.T: squin.t, + uop_stmts.Sdag: squin.s_adj, uop_stmts.SX: squin.sqrt_x, + uop_stmts.SXdag: squin.sqrt_x_adj, + uop_stmts.Tdag: squin.t_adj, + uop_stmts.T: squin.t, } diff --git a/test/squin/passes/test_qasm2_to_squin.py b/test/squin/passes/test_qasm2_to_squin.py index 5d05cf32..5829da28 100644 --- a/test/squin/passes/test_qasm2_to_squin.py +++ b/test/squin/passes/test_qasm2_to_squin.py @@ -71,8 +71,11 @@ def non_parametric_gates(): qasm2.z(q[0]) qasm2.h(q[0]) qasm2.s(q[0]) + qasm2.sdg(q[0]) qasm2.t(q[0]) + qasm2.tdg(q[0]) qasm2.sx(q[0]) + qasm2.sxdg(q[0]) # 2q gates qasm2.cx(q[0], q[1]) @@ -105,8 +108,11 @@ def non_parametric_gates(): squin.gate.stmts.Z, squin.gate.stmts.H, squin.gate.stmts.S, + squin.gate.stmts.S, # adjoint=True squin.gate.stmts.T, + squin.gate.stmts.T, # adjoint=True squin.gate.stmts.SqrtX, + squin.gate.stmts.SqrtX, # adjoint=True squin.gate.stmts.CX, squin.gate.stmts.CY, squin.gate.stmts.CZ, @@ -114,6 +120,23 @@ def non_parametric_gates(): assert [type(stmt) for stmt in actual_stmts] == expected_stmts + has_s_adj = False + has_t_adj = False + has_sqrtx_adj = False + + for stmt in actual_stmts: + + if type(stmt) is squin.gate.stmts.S and stmt.adjoint: + has_s_adj = True + elif type(stmt) is squin.gate.stmts.T and stmt.adjoint: + has_t_adj = True + elif type(stmt) is squin.gate.stmts.SqrtX and stmt.adjoint: + has_sqrtx_adj = True + + assert has_s_adj + assert has_t_adj + assert has_sqrtx_adj + def test_parametric_gates(): From 6390a2360296884e5d539f46fc43179cfd58219a Mon Sep 17 00:00:00 2001 From: John Long Date: Sun, 7 Dec 2025 13:09:57 -0500 Subject: [PATCH 12/13] final full test for everything --- test/squin/passes/test_qasm2_to_squin.py | 114 +++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/test/squin/passes/test_qasm2_to_squin.py b/test/squin/passes/test_qasm2_to_squin.py index 5829da28..b8dc8dc9 100644 --- a/test/squin/passes/test_qasm2_to_squin.py +++ b/test/squin/passes/test_qasm2_to_squin.py @@ -14,6 +14,7 @@ from bloqade.analysis.address import AddressAnalysis from bloqade.qasm2.passes.qasm2py import _QASM2Py as QASM2ToPyRule from bloqade.analysis.address.lattice import AddressReg +from bloqade.squin.passes.qasm2_to_squin import QASM2ToSquin from bloqade.squin.passes.qasm2_gate_func_to_squin import QASM2GateFuncToSquinPass @@ -342,3 +343,116 @@ def main_kernel(): actual_stmts = [type(stmt) for stmt in main_kernel.callable_region.walk()] assert actual_stmts.count(squin.gate.stmts.CX) == 1 assert actual_stmts.count(squin.gate.stmts.X) == 1 + + +def test_ccx_to_2_and_1q_gates(): + + # use Nielsen and Chuang decomposition of CCX/Toffoli + # into 2 and 1 qubit gates. This is intentionally + # more complex than necessary just to test things out. + + @qasm2.main + def layer_1(ctrl1, ctrl2, target): + qasm2.h(target) + qasm2.cx(ctrl2, target) + qasm2.tdg(target) + qasm2.cx(ctrl1, target) + return + + @qasm2.main + def layer_2(ctrl1, ctrl2, target): + qasm2.t(ctrl1) + qasm2.cx(ctrl2, target) + qasm2.tdg(target) + qasm2.cx(ctrl1, target) + + @qasm2.main + def layer_3(ctrl1, ctrl2, target): + qasm2.t(ctrl2) + qasm2.t(target) + qasm2.cx(ctrl1, ctrl2) + qasm2.h(target) + qasm2.t(ctrl1) + qasm2.tdg(ctrl2) + qasm2.cx(ctrl1, ctrl2) + return + + @qasm2.extended + def lossy_toffoli_gate(ctrl1, ctrl2, target): + + layer_1(ctrl1, ctrl2, target) + noise.atom_loss_channel(qargs=[ctrl1, ctrl2, target], prob=0.01) + layer_2(ctrl1, ctrl2, target) + noise.atom_loss_channel(qargs=[ctrl1, ctrl2, target], prob=0.01) + layer_3(ctrl1, ctrl2, target) + noise.atom_loss_channel(qargs=[ctrl1, ctrl2, target], prob=0.01) + return + + @qasm2.extended + def rotation_layer(qs): + + glob.u(qs, math.pi, math.pi / 2, math.pi / 16) + parallel.rz(qs, math.pi / 4) + parallel.u(qs, math.pi / 2, math.pi / 2, math.pi / 8) + return + + @qasm2.main + def main(): + qs = qasm2.qreg(3) + # set up the control qubits + qasm2.x(qs[0]) + qasm2.x(qs[1]) + lossy_toffoli_gate(qs[0], qs[1], qs[2]) + rotation_layer(qs) + + QASM2ToSquin(dialects=main.dialects)(main) + AggressiveUnroll(dialects=main.dialects).fixpoint(main) + + actual_stmts = [ + type(stmt) + for stmt in main.callable_region.walk() + if isinstance(stmt, squin.gate.stmts.Gate) + or isinstance(stmt, squin.noise.stmts.NoiseChannel) + ] + expected_stmts = [ + squin.gate.stmts.X, + squin.gate.stmts.X, + # go into the toffoli + ## layer 1 + squin.gate.stmts.H, + squin.gate.stmts.CX, + squin.gate.stmts.T, # adjoint=True + squin.gate.stmts.CX, + ### atom loss + squin.noise.stmts.QubitLoss, + ## layer 2 + squin.gate.stmts.T, + squin.gate.stmts.CX, + squin.gate.stmts.T, # adjoint=True + squin.gate.stmts.CX, + ### atom loss + squin.noise.stmts.QubitLoss, + ## layer 3 + squin.gate.stmts.T, + squin.gate.stmts.T, + squin.gate.stmts.CX, + squin.gate.stmts.H, + squin.gate.stmts.T, + squin.gate.stmts.T, # adjoint=True + squin.gate.stmts.CX, + ### atom loss + squin.noise.stmts.QubitLoss, + # random rotation layer + squin.gate.stmts.U3, + squin.gate.stmts.Rz, + squin.gate.stmts.U3, + ] + + assert actual_stmts == expected_stmts + + num_T_adj = 0 + for stmt in main.callable_region.walk(): + if isinstance(stmt, squin.gate.stmts.T) and stmt.adjoint: + num_T_adj += 1 + + assert num_T_adj == 3 From 5f137de39d4a8870b40734bcf603f7aa3cbd57f2 Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 9 Dec 2025 08:53:24 -0500 Subject: [PATCH 13/13] get rid of unnecessary util file --- .../squin/rewrite/qasm2/noise_to_squin.py | 17 ++++++++++++++--- src/bloqade/squin/rewrite/qasm2/util.py | 15 --------------- 2 files changed, 14 insertions(+), 18 deletions(-) delete mode 100644 src/bloqade/squin/rewrite/qasm2/util.py diff --git a/src/bloqade/squin/rewrite/qasm2/noise_to_squin.py b/src/bloqade/squin/rewrite/qasm2/noise_to_squin.py index 90b03553..938adeda 100644 --- a/src/bloqade/squin/rewrite/qasm2/noise_to_squin.py +++ b/src/bloqade/squin/rewrite/qasm2/noise_to_squin.py @@ -1,18 +1,29 @@ from kirin import ir -from kirin.dialects import func +from kirin.dialects import py, func from kirin.rewrite.abc import RewriteRule, RewriteResult from bloqade import squin from bloqade.qasm2.dialects.noise import stmts as noise_stmts -from .util import num_to_py_constant - NOISE_TO_SQUIN_MAP = { noise_stmts.AtomLossChannel: squin.broadcast.qubit_loss, noise_stmts.PauliChannel: squin.broadcast.single_qubit_pauli_channel, } +def num_to_py_constant( + values: list[int | float], stmt_to_insert_before: ir.Statement +) -> list[ir.SSAValue]: + + py_const_ssa_vals = [] + for value in values: + const_form = py.Constant(value=value) + const_form.insert_before(stmt_to_insert_before) + py_const_ssa_vals.append(const_form.result) + + return py_const_ssa_vals + + class QASM2NoiseToSquin(RewriteRule): def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: diff --git a/src/bloqade/squin/rewrite/qasm2/util.py b/src/bloqade/squin/rewrite/qasm2/util.py deleted file mode 100644 index a74528fd..00000000 --- a/src/bloqade/squin/rewrite/qasm2/util.py +++ /dev/null @@ -1,15 +0,0 @@ -from kirin import ir -from kirin.dialects import py - - -def num_to_py_constant( - values: list[int | float], stmt_to_insert_before: ir.Statement -) -> list[ir.SSAValue]: - - py_const_ssa_vals = [] - for value in values: - const_form = py.Constant(value=value) - const_form.insert_before(stmt_to_insert_before) - py_const_ssa_vals.append(const_form.result) - - return py_const_ssa_vals