diff --git a/src/bloqade/qasm2/dialects/uop/stmts.py b/src/bloqade/qasm2/dialects/uop/stmts.py index 755b50e55..55f556100 100644 --- a/src/bloqade/qasm2/dialects/uop/stmts.py +++ b/src/bloqade/qasm2/dialects/uop/stmts.py @@ -28,7 +28,7 @@ class TwoQubitCtrlGate(ir.Statement): @statement(dialect=dialect) class CX(TwoQubitCtrlGate): - """Alias for the CNOT or CH gate operations.""" + """Alias for the CNOT or CX gate operations.""" name = "CX" # Note this is capitalized diff --git a/src/bloqade/squin/passes/__init__.py b/src/bloqade/squin/passes/__init__.py new file mode 100644 index 000000000..6dceb5339 --- /dev/null +++ b/src/bloqade/squin/passes/__init__.py @@ -0,0 +1 @@ +from .qasm2_to_squin import QASM2ToSquin as QASM2ToSquin diff --git a/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py b/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py new file mode 100644 index 000000000..613616d2d --- /dev/null +++ b/src/bloqade/squin/passes/qasm2_gate_func_to_squin.py @@ -0,0 +1,57 @@ +from kirin import ir, passes +from kirin.rewrite import Walk, Chain +from kirin.dialects import func +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade.rewrite.passes import CallGraphPass +from bloqade.qasm2.passes.qasm2py import _QASM2Py as QASM2ToPyRule + +from ..rewrite import qasm2 as qasm2_rule + + +class QASM2GateFuncToKirinFunc(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + from bloqade.qasm2.dialects.expr.stmts import GateFunction + + if not isinstance(node, GateFunction): + return RewriteResult() + + kirin_func = func.Function( + sym_name=node.sym_name, + signature=node.signature, + body=node.body, + slots=node.slots, + ) + node.replace_by(kirin_func) + + return RewriteResult(has_done_something=True) + + +class QASM2GateFuncToSquinPass(passes.Pass): + + def unsafe_run(self, mt: ir.Method) -> RewriteResult: + convert_to_kirin_func = CallGraphPass( + dialects=mt.dialects, rule=Walk(QASM2GateFuncToKirinFunc()) + ) + rewrite_result = convert_to_kirin_func(mt) + + combined_qasm2_rules = Walk( + Chain( + QASM2ToPyRule(), + qasm2_rule.QASM2CoreToSquin(), + qasm2_rule.QASM2GlobParallelToSquin(), + qasm2_rule.QASM2NoiseToSquin(), + qasm2_rule.QASM2IdToSquin(), + qasm2_rule.QASM2UOp1QToSquin(), + qasm2_rule.QASM2ParametrizedUOp1QToSquin(), + qasm2_rule.QASM2UOp2QToSquin(), + ) + ) + + body_conversion_pass = CallGraphPass( + dialects=mt.dialects, rule=combined_qasm2_rules + ) + rewrite_result = body_conversion_pass(mt).join(rewrite_result) + + return rewrite_result diff --git a/src/bloqade/squin/passes/qasm2_to_squin.py b/src/bloqade/squin/passes/qasm2_to_squin.py new file mode 100644 index 000000000..d23605d7b --- /dev/null +++ b/src/bloqade/squin/passes/qasm2_to_squin.py @@ -0,0 +1,66 @@ +from dataclasses import dataclass + +from kirin import ir +from kirin.passes import Fold, Pass, TypeInfer +from kirin.rewrite import Walk, Chain +from kirin.rewrite.abc import RewriteResult +from kirin.dialects.ilist.passes import IListDesugar + +from bloqade import squin +from bloqade.squin.rewrite.qasm2 import ( + QASM2IdToSquin, + QASM2CoreToSquin, + QASM2NoiseToSquin, + QASM2UOp1QToSquin, + QASM2UOp2QToSquin, + QASM2GlobParallelToSquin, + QASM2ParametrizedUOp1QToSquin, +) + +# There's a QASM2Py pass that only applies an _QASM2Py rewrite rule, +# I just want the rule here. +from bloqade.qasm2.passes.qasm2py import _QASM2Py as QASM2ToPyRule + +from .qasm2_gate_func_to_squin import QASM2GateFuncToSquinPass + + +@dataclass +class QASM2ToSquin(Pass): + + def unsafe_run(self, mt: ir.Method) -> RewriteResult: + + # rewrite all QASM2 to squin first + rewrite_result = Walk( + Chain( + QASM2ToPyRule(), + QASM2CoreToSquin(), + QASM2GlobParallelToSquin(), + QASM2NoiseToSquin(), + QASM2IdToSquin(), + QASM2UOp1QToSquin(), + QASM2ParametrizedUOp1QToSquin(), + QASM2UOp2QToSquin(), + ) + ).rewrite(mt.code) + + # go into subkernels + rewrite_result = ( + QASM2GateFuncToSquinPass(dialects=mt.dialects) + .unsafe_run(mt) + .join(rewrite_result) + ) + + # kernel should be entirely in squin dialect now + mt.dialects = squin.kernel + + # the rest is taken from the squin kernel + rewrite_result = Fold(dialects=mt.dialects).fixpoint(mt) + rewrite_result = ( + TypeInfer(dialects=mt.dialects).unsafe_run(mt).join(rewrite_result) + ) + rewrite_result = ( + IListDesugar(dialects=mt.dialects).unsafe_run(mt).join(rewrite_result) + ) + TypeInfer(dialects=mt.dialects).unsafe_run(mt).join(rewrite_result) + + return rewrite_result diff --git a/src/bloqade/squin/rewrite/qasm2/__init__.py b/src/bloqade/squin/rewrite/qasm2/__init__.py new file mode 100644 index 000000000..0cd85c7c2 --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/__init__.py @@ -0,0 +1,9 @@ +from .id_to_squin import QASM2IdToSquin as QASM2IdToSquin +from .core_to_squin import QASM2CoreToSquin as QASM2CoreToSquin +from .noise_to_squin import QASM2NoiseToSquin as QASM2NoiseToSquin +from .uop_1q_to_squin import QASM2UOp1QToSquin as QASM2UOp1QToSquin +from .uop_2q_to_squin import QASM2UOp2QToSquin as QASM2UOp2QToSquin +from .glob_parallel_to_squin import QASM2GlobParallelToSquin as QASM2GlobParallelToSquin +from .parametrized_uop_1q_to_squin import ( + QASM2ParametrizedUOp1QToSquin as QASM2ParametrizedUOp1QToSquin, +) diff --git a/src/bloqade/squin/rewrite/qasm2/core_to_squin.py b/src/bloqade/squin/rewrite/qasm2/core_to_squin.py new file mode 100644 index 000000000..29de3f486 --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/core_to_squin.py @@ -0,0 +1,38 @@ +from kirin import ir +from kirin.dialects import py, func +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import squin +from bloqade.qasm2.dialects.core import stmts as core_stmts + +CORE_TO_SQUIN_MAP = { + core_stmts.QRegNew: squin.qubit.qalloc, + core_stmts.Reset: squin.qubit.reset, +} + + +class QASM2CoreToSquin(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + if isinstance(node, core_stmts.QRegGet): + py_get_item = py.GetItem( + obj=node.reg, + index=node.idx, + ) + node.replace_by(py_get_item) + return RewriteResult(has_done_something=True) + + if isinstance(node, core_stmts.QRegNew): + args = (node.n_qubits,) + elif isinstance(node, core_stmts.Reset): + args = (node.qarg,) + else: + return RewriteResult() + + new_stmt = func.Invoke( + callee=CORE_TO_SQUIN_MAP[type(node)], + inputs=args, + ) + node.replace_by(new_stmt) + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/qasm2/glob_parallel_to_squin.py b/src/bloqade/squin/rewrite/qasm2/glob_parallel_to_squin.py new file mode 100644 index 000000000..9fe52471f --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/glob_parallel_to_squin.py @@ -0,0 +1,34 @@ +from kirin import ir +from kirin.dialects import func +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import squin +from bloqade.qasm2.dialects import glob, parallel + +GLOBAL_PARALLEL_TO_SQUIN_MAP = { + glob.UGate: squin.broadcast.u3, + parallel.UGate: squin.broadcast.u3, + parallel.RZ: squin.broadcast.rz, +} + + +class QASM2GlobParallelToSquin(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + if isinstance(node, glob.UGate): + args = (node.theta, node.phi, node.lam, node.registers) + elif isinstance(node, parallel.UGate): + args = (node.theta, node.phi, node.lam, node.qargs) + elif isinstance(node, parallel.RZ): + args = (node.theta, node.qargs) + else: + return RewriteResult() + + squin_equivalent_stmt = GLOBAL_PARALLEL_TO_SQUIN_MAP[type(node)] + invoke_stmt = func.Invoke( + callee=squin_equivalent_stmt, + inputs=args, + ) + node.replace_by(invoke_stmt) + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/qasm2/id_to_squin.py b/src/bloqade/squin/rewrite/qasm2/id_to_squin.py new file mode 100644 index 000000000..b5f12f576 --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/id_to_squin.py @@ -0,0 +1,15 @@ +from kirin import ir +from kirin.rewrite.abc import RewriteRule, RewriteResult + +import bloqade.qasm2.dialects.uop.stmts as uop_stmts + + +class QASM2IdToSquin(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + if not isinstance(node, uop_stmts.Id): + return RewriteResult() + + node.delete() + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/qasm2/noise_to_squin.py b/src/bloqade/squin/rewrite/qasm2/noise_to_squin.py new file mode 100644 index 000000000..938adedad --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/noise_to_squin.py @@ -0,0 +1,91 @@ +from kirin import ir +from kirin.dialects import py, func +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import squin +from bloqade.qasm2.dialects.noise import stmts as noise_stmts + +NOISE_TO_SQUIN_MAP = { + noise_stmts.AtomLossChannel: squin.broadcast.qubit_loss, + noise_stmts.PauliChannel: squin.broadcast.single_qubit_pauli_channel, +} + + +def num_to_py_constant( + values: list[int | float], stmt_to_insert_before: ir.Statement +) -> list[ir.SSAValue]: + + py_const_ssa_vals = [] + for value in values: + const_form = py.Constant(value=value) + const_form.insert_before(stmt_to_insert_before) + py_const_ssa_vals.append(const_form.result) + + return py_const_ssa_vals + + +class QASM2NoiseToSquin(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + if isinstance(node, noise_stmts.AtomLossChannel): + qargs = node.qargs + prob = node.prob + prob_ssas = num_to_py_constant([prob], stmt_to_insert_before=node) + elif isinstance(node, noise_stmts.PauliChannel): + qargs = node.qargs + p_x = node.px + p_y = node.py + p_z = node.pz + prob_ssas = num_to_py_constant([p_x, p_y, p_z], stmt_to_insert_before=node) + elif isinstance(node, noise_stmts.CZPauliChannel): + return self.rewrite_CZPauliChannel(node) + else: + return RewriteResult() + + squin_noise_stmt = NOISE_TO_SQUIN_MAP[type(node)] + invoke_stmt = func.Invoke( + callee=squin_noise_stmt, + inputs=(*prob_ssas, qargs), + ) + node.replace_by(invoke_stmt) + return RewriteResult(has_done_something=True) + + def rewrite_CZPauliChannel(self, stmt: noise_stmts.CZPauliChannel) -> RewriteResult: + + ctrls = stmt.ctrls + qargs = stmt.qargs + + px_ctrl = stmt.px_ctrl + py_ctrl = stmt.py_ctrl + pz_ctrl = stmt.pz_ctrl + px_qarg = stmt.px_qarg + py_qarg = stmt.py_qarg + pz_qarg = stmt.pz_qarg + + error_probs = [px_ctrl, py_ctrl, pz_ctrl, px_qarg, py_qarg, pz_qarg] + # first half of entries for control qubits, other half for targets + + error_prob_ssas = num_to_py_constant(error_probs, stmt_to_insert_before=stmt) + + ctrl_pauli_channel_invoke = func.Invoke( + callee=squin.broadcast.single_qubit_pauli_channel, + inputs=( + *error_prob_ssas[:3], + ctrls, + ), + ) + + qarg_pauli_channel_invoke = func.Invoke( + callee=squin.broadcast.single_qubit_pauli_channel, + inputs=( + *error_prob_ssas[3:], + qargs, + ), + ) + + ctrl_pauli_channel_invoke.insert_before(stmt) + qarg_pauli_channel_invoke.insert_before(stmt) + stmt.delete() + + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/qasm2/parametrized_uop_1q_to_squin.py b/src/bloqade/squin/rewrite/qasm2/parametrized_uop_1q_to_squin.py new file mode 100644 index 000000000..ffca71e0e --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/parametrized_uop_1q_to_squin.py @@ -0,0 +1,46 @@ +from math import pi + +from kirin import ir +from kirin.dialects import py, func +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import squin +from bloqade.qasm2.dialects.uop import stmts as uop_stmts + +PARAMETRIZED_1Q_GATES_TO_SQUIN_MAP = { + uop_stmts.UGate: squin.u3, + uop_stmts.U1: squin.u3, + uop_stmts.U2: squin.u3, + uop_stmts.RZ: squin.rz, + uop_stmts.RX: squin.rx, + uop_stmts.RY: squin.ry, +} + + +class QASM2ParametrizedUOp1QToSquin(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + if isinstance(node, (uop_stmts.RX, uop_stmts.RY, uop_stmts.RZ)): + args = (node.theta, node.qarg) + elif isinstance(node, (uop_stmts.UGate)): + args = (node.theta, node.phi, node.lam, node.qarg) + elif isinstance(node, (uop_stmts.U1)): + zero_stmt = py.Constant(value=0.0) + zero_stmt.insert_before(node) + args = (zero_stmt.result, zero_stmt.result, node.lam, node.qarg) + elif isinstance(node, (uop_stmts.U2)): + half_pi_stmt = py.Constant(value=pi / 2) + half_pi_stmt.insert_before(node) + args = (half_pi_stmt.result, node.phi, node.lam, node.qarg) + else: + return RewriteResult() + + squin_equivalent_stmt = PARAMETRIZED_1Q_GATES_TO_SQUIN_MAP[type(node)] + invoke_stmt = func.Invoke( + callee=squin_equivalent_stmt, + inputs=args, + ) + node.replace_by(invoke_stmt) + + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/qasm2/uop_1q_to_squin.py b/src/bloqade/squin/rewrite/qasm2/uop_1q_to_squin.py new file mode 100644 index 000000000..c9bfb34e4 --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/uop_1q_to_squin.py @@ -0,0 +1,36 @@ +from kirin import ir +from kirin.dialects import func +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import squin +from bloqade.qasm2.dialects.uop import stmts as uop_stmts + +ONE_Q_GATES_TO_SQUIN_MAP = { + uop_stmts.X: squin.x, + uop_stmts.Y: squin.y, + uop_stmts.Z: squin.z, + uop_stmts.H: squin.h, + uop_stmts.S: squin.s, + uop_stmts.Sdag: squin.s_adj, + uop_stmts.SX: squin.sqrt_x, + uop_stmts.SXdag: squin.sqrt_x_adj, + uop_stmts.Tdag: squin.t_adj, + uop_stmts.T: squin.t, +} + + +class QASM2UOp1QToSquin(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + squin_1q_gate = ONE_Q_GATES_TO_SQUIN_MAP.get(type(node)) + if squin_1q_gate is None: + return RewriteResult() + + invoke_stmt = func.Invoke( + callee=squin_1q_gate, + inputs=(node.qarg,), + ) + node.replace_by(invoke_stmt) + + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/qasm2/uop_2q_to_squin.py b/src/bloqade/squin/rewrite/qasm2/uop_2q_to_squin.py new file mode 100644 index 000000000..5aeeb7939 --- /dev/null +++ b/src/bloqade/squin/rewrite/qasm2/uop_2q_to_squin.py @@ -0,0 +1,28 @@ +from kirin import ir +from kirin.dialects import func +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import squin +from bloqade.qasm2.dialects.uop import stmts as uop_stmts + +CONTROLLED_GATES_TO_SQUIN_MAP = { + uop_stmts.CX: squin.cx, + uop_stmts.CZ: squin.cz, + uop_stmts.CY: squin.cy, +} + + +class QASM2UOp2QToSquin(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + squin_controlled_gate = CONTROLLED_GATES_TO_SQUIN_MAP.get(type(node)) + if squin_controlled_gate is None: + return RewriteResult() + + invoke_stmt = func.Invoke( + callee=squin_controlled_gate, + inputs=(node.ctrl, node.qarg), + ) + node.replace_by(invoke_stmt) + return RewriteResult(has_done_something=True) diff --git a/test/analysis/measure_id/test_refactored_measure_id.py b/test/analysis/measure_id/test_refactored_measure_id.py new file mode 100644 index 000000000..107ed1cfe --- /dev/null +++ b/test/analysis/measure_id/test_refactored_measure_id.py @@ -0,0 +1,32 @@ +import io + +from kirin import ir + +from bloqade import stim, squin +from bloqade.stim.emit import EmitStimMain +from bloqade.stim.passes import SquinToStimPass + + +def codegen(mt: ir.Method): + # method should not have any arguments! + buf = io.StringIO() + emit = EmitStimMain(dialects=stim.main, io=buf) + emit.initialize() + emit.run(mt) + return buf.getvalue().strip() + + +@squin.kernel +def test_simple_linear(): + + qs = squin.qalloc(4) + m0 = squin.broadcast.measure(qs) + squin.set_detector([m0[0], m0[1]], coordinates=[0, 0]) + m1 = squin.broadcast.measure(qs) + squin.set_detector([m1[0], m1[1]], coordinates=[1, 1]) + + +test_simple_linear.print() +SquinToStimPass(dialects=test_simple_linear.dialects)(test_simple_linear) +test_simple_linear.print() +print(codegen(test_simple_linear)) diff --git a/test/squin/passes/test_qasm2_to_squin.py b/test/squin/passes/test_qasm2_to_squin.py new file mode 100644 index 000000000..b8dc8dc96 --- /dev/null +++ b/test/squin/passes/test_qasm2_to_squin.py @@ -0,0 +1,458 @@ +import math + +from kirin import types as kirin_types +from kirin.passes import TypeInfer +from kirin.rewrite import Walk, Chain +from kirin.dialects import py +from kirin.dialects.ilist import IListType + +import bloqade.squin.rewrite.qasm2 as qasm2_rules +from bloqade import qasm2, squin +from bloqade.qasm2 import glob, noise, parallel +from bloqade.types import QubitType +from bloqade.rewrite.passes import AggressiveUnroll +from bloqade.analysis.address import AddressAnalysis +from bloqade.qasm2.passes.qasm2py import _QASM2Py as QASM2ToPyRule +from bloqade.analysis.address.lattice import AddressReg +from bloqade.squin.passes.qasm2_to_squin import QASM2ToSquin +from bloqade.squin.passes.qasm2_gate_func_to_squin import QASM2GateFuncToSquinPass + + +def test_qasm2_core(): + + @qasm2.main + def core_kernel(): + q = qasm2.qreg(5) + q0 = q[0] + qasm2.reset(q0) + return q0 + + Walk( + Chain( + QASM2ToPyRule(), + qasm2_rules.QASM2CoreToSquin(), + ) + ).rewrite(core_kernel.code) + + TypeInfer(dialects=squin.kernel)(core_kernel) + + core_kernel.print() + + stmts = list(core_kernel.callable_region.walk()) + get_item_stmt = [stmt for stmt in stmts if isinstance(stmt, py.GetItem)] + assert len(get_item_stmt) == 1 + get_item_stmt = get_item_stmt[0] + assert get_item_stmt.obj.type == IListType[QubitType, kirin_types.Any] + idx_const = get_item_stmt.index.owner + assert idx_const.value.data == 0 + + # do aggressive unroll, confirm there are 5 qubit.news() + AggressiveUnroll(dialects=squin.kernel).fixpoint(core_kernel) + + unrolled_stmts = list(core_kernel.callable_region.walk()) + filtered_stmts = [ + stmt + for stmt in unrolled_stmts + if isinstance(stmt, (squin.qubit.stmts.New, squin.qubit.stmts.Reset)) + ] + expected_stmts = [squin.qubit.stmts.New] * 5 + [squin.qubit.stmts.Reset] + + assert [type(stmt) for stmt in filtered_stmts] == expected_stmts + + +def test_non_parametric_gates(): + + @qasm2.main + def non_parametric_gates(): + # 1q gates + q = qasm2.qreg(1) + qasm2.id(q[0]) + qasm2.x(q[0]) + qasm2.y(q[0]) + qasm2.z(q[0]) + qasm2.h(q[0]) + qasm2.s(q[0]) + qasm2.sdg(q[0]) + qasm2.t(q[0]) + qasm2.tdg(q[0]) + qasm2.sx(q[0]) + qasm2.sxdg(q[0]) + + # 2q gates + qasm2.cx(q[0], q[1]) + qasm2.cy(q[0], q[1]) + qasm2.cz(q[0], q[1]) + + return q + + non_parametric_gates.print() + Walk( + Chain( + QASM2ToPyRule(), + qasm2_rules.QASM2CoreToSquin(), + qasm2_rules.QASM2IdToSquin(), + qasm2_rules.QASM2UOp1QToSquin(), + qasm2_rules.QASM2UOp2QToSquin(), + ) + ).rewrite(non_parametric_gates.code) + AggressiveUnroll(dialects=squin.kernel).fixpoint(non_parametric_gates) + + actual_stmts = list(non_parametric_gates.callable_region.walk()) + # should be no identity whatsoever + assert not any(isinstance(stmt, qasm2.uop.stmts.Id) for stmt in actual_stmts) + actual_stmts = [ + stmt for stmt in actual_stmts if isinstance(stmt, squin.gate.stmts.Gate) + ] + expected_stmts = [ + squin.gate.stmts.X, + squin.gate.stmts.Y, + squin.gate.stmts.Z, + squin.gate.stmts.H, + squin.gate.stmts.S, + squin.gate.stmts.S, # adjoint=True + squin.gate.stmts.T, + squin.gate.stmts.T, # adjoint=True + squin.gate.stmts.SqrtX, + squin.gate.stmts.SqrtX, # adjoint=True + squin.gate.stmts.CX, + squin.gate.stmts.CY, + squin.gate.stmts.CZ, + ] + + assert [type(stmt) for stmt in actual_stmts] == expected_stmts + + has_s_adj = False + has_t_adj = False + has_sqrtx_adj = False + + for stmt in actual_stmts: + + if type(stmt) is squin.gate.stmts.S and stmt.adjoint: + has_s_adj = True + elif type(stmt) is squin.gate.stmts.T and stmt.adjoint: + has_t_adj = True + elif type(stmt) is squin.gate.stmts.SqrtX and stmt.adjoint: + has_sqrtx_adj = True + + assert has_s_adj + assert has_t_adj + assert has_sqrtx_adj + + +def test_parametric_gates(): + + const_pi = math.pi + + @qasm2.main + def rotation_gates(): + q = qasm2.qreg(3) + half_turn = const_pi + quarter_turn = const_pi / 2 + eighth_turn = const_pi / 4 + qasm2.rz(q[0], half_turn) + qasm2.rx(q[1], half_turn) + qasm2.ry(q[2], half_turn) + + qasm2.u3(q[0], half_turn, quarter_turn, eighth_turn) + qasm2.u2(q[0], quarter_turn, half_turn) + qasm2.u1(q[0], eighth_turn) + return q + + rotation_gates.print() + # QASM2ToSquin(dialects=squin.kernel)(rotation_gates) + Walk( + Chain( + QASM2ToPyRule(), + qasm2_rules.QASM2CoreToSquin(), + qasm2_rules.QASM2ParametrizedUOp1QToSquin(), + ) + ).rewrite(rotation_gates.code) + AggressiveUnroll(dialects=squin.kernel).fixpoint(rotation_gates) + + actual_stmts = list(rotation_gates.callable_region.walk()) + actual_stmts = [ + stmt for stmt in actual_stmts if isinstance(stmt, squin.gate.stmts.Gate) + ] + + assert len(actual_stmts) == 6 + + assert type(actual_stmts[0]) is squin.gate.stmts.Rz + assert actual_stmts[0].angle.owner.value.data == 0.5 + assert type(actual_stmts[1]) is squin.gate.stmts.Rx + assert actual_stmts[1].angle.owner.value.data == 0.5 + assert type(actual_stmts[2]) is squin.gate.stmts.Ry + assert actual_stmts[2].angle.owner.value.data == 0.5 + + # U3 + assert type(actual_stmts[3]) is squin.gate.stmts.U3 + assert actual_stmts[3].theta.owner.value.data == 0.5 + assert actual_stmts[3].phi.owner.value.data == 0.25 + assert actual_stmts[3].lam.owner.value.data == 0.125 + + # U2 is just U3(pi/2, phi, lam) + assert type(actual_stmts[4]) is squin.gate.stmts.U3 + assert actual_stmts[4].theta.owner.value.data == 0.25 + assert actual_stmts[4].phi.owner.value.data == 0.25 + assert actual_stmts[4].lam.owner.value.data == 0.5 + + assert type(actual_stmts[5]) is squin.gate.stmts.U3 + assert actual_stmts[5].theta.owner.value.data == 0.0 + assert actual_stmts[5].phi.owner.value.data == 0.0 + assert actual_stmts[5].lam.owner.value.data == 0.125 + + +def test_noise(): + + @qasm2.extended + def noise_program(): + q = qasm2.qreg(4) + noise.atom_loss_channel(qargs=q, prob=0.05) + noise.pauli_channel(qargs=q, px=0.01, py=0.02, pz=0.03) + noise.cz_pauli_channel( + ctrls=[q[0], q[1]], + qargs=[q[2], q[3]], + px_ctrl=0.01, + py_ctrl=0.02, + pz_ctrl=0.03, + px_qarg=0.04, + py_qarg=0.05, + pz_qarg=0.06, + paired=True, + ) + return q + + noise_program.print() + Walk( + Chain( + qasm2_rules.QASM2CoreToSquin(), + qasm2_rules.QASM2NoiseToSquin(), + ) + ).rewrite(noise_program.code) + AggressiveUnroll(dialects=squin.kernel).fixpoint(noise_program) + noise_program.print() + frame, _ = AddressAnalysis(dialects=squin.kernel).run(noise_program) + + actual_stmts = list(noise_program.callable_region.walk()) + actual_stmts = [ + stmt + for stmt in actual_stmts + if isinstance(stmt, squin.noise.stmts.NoiseChannel) + ] + + assert type(actual_stmts[0]) is squin.noise.stmts.QubitLoss + assert actual_stmts[0].p.owner.value.data == 0.05 + + assert type(actual_stmts[1]) is squin.noise.stmts.SingleQubitPauliChannel + assert actual_stmts[1].px.owner.value.data == 0.01 + assert actual_stmts[1].py.owner.value.data == 0.02 + assert actual_stmts[1].pz.owner.value.data == 0.03 + + # originate from the same cz_pauli_channel + assert type(actual_stmts[2]) is squin.noise.stmts.SingleQubitPauliChannel + assert type(actual_stmts[3]) is squin.noise.stmts.SingleQubitPauliChannel + # control qubits + assert frame.get(actual_stmts[2].qubits) == AddressReg(data=(0, 1)) + # target qubits + assert frame.get(actual_stmts[3].qubits) == AddressReg(data=(2, 3)) + # assert probabilities are correct on control + assert actual_stmts[2].px.owner.value.data == 0.01 + assert actual_stmts[2].py.owner.value.data == 0.02 + assert actual_stmts[2].pz.owner.value.data == 0.03 + # assert probabilities are correct on targets + assert actual_stmts[3].px.owner.value.data == 0.04 + assert actual_stmts[3].py.owner.value.data == 0.05 + assert actual_stmts[3].pz.owner.value.data == 0.06 + + +def test_global_and_parallel(): + + const_pi = math.pi + + @qasm2.extended + def global_parallel_program(): + q = qasm2.qreg(6) + half_turn = const_pi + quarter_turn = const_pi / 2 + eighth_turn = const_pi / 4 + + parallel.u([q[0]], half_turn, quarter_turn, eighth_turn) + glob.u([q[2], q[1], q[3]], half_turn, quarter_turn, eighth_turn) + parallel.rz([q[4], q[5]], eighth_turn) + return q + + global_parallel_program.print() + Walk( + Chain( + QASM2ToPyRule(), + qasm2_rules.QASM2GlobParallelToSquin(), + ) + ).rewrite(global_parallel_program.code) + AggressiveUnroll(dialects=global_parallel_program.dialects).fixpoint( + global_parallel_program + ) + + actual_stmts = list(global_parallel_program.callable_region.walk()) + actual_stmts = [ + stmt for stmt in actual_stmts if isinstance(stmt, squin.gate.stmts.Gate) + ] + + assert type(actual_stmts[0]) is squin.gate.stmts.U3 + assert actual_stmts[0].theta.owner.value.data == 0.5 + assert actual_stmts[0].phi.owner.value.data == 0.25 + assert actual_stmts[0].lam.owner.value.data == 0.125 + + assert type(actual_stmts[1]) is squin.gate.stmts.U3 + assert actual_stmts[1].theta.owner.value.data == 0.5 + assert actual_stmts[1].phi.owner.value.data == 0.25 + assert actual_stmts[1].lam.owner.value.data == 0.125 + + assert type(actual_stmts[2]) is squin.gate.stmts.Rz + assert actual_stmts[2].angle.owner.value.data == 0.125 + + +def test_func(): + + @qasm2.main + def another_sub_kernel(q): + qasm2.x(q) + return + + @qasm2.main + def sub_kernel(ctrl, target): + qasm2.cx(ctrl, target) + another_sub_kernel(ctrl) + return + + @qasm2.main + def main_kernel(): + q = qasm2.qreg(2) + sub_kernel(q[0], q[1]) + return q + + Walk( + Chain( + QASM2ToPyRule(), + qasm2_rules.QASM2CoreToSquin(), + ) + ).rewrite(main_kernel.code) + + QASM2GateFuncToSquinPass(dialects=main_kernel.dialects).unsafe_run(main_kernel) + + AggressiveUnroll(dialects=main_kernel.dialects).fixpoint(main_kernel) + + actual_stmts = [type(stmt) for stmt in main_kernel.callable_region.walk()] + assert actual_stmts.count(squin.gate.stmts.CX) == 1 + assert actual_stmts.count(squin.gate.stmts.X) == 1 + + +def test_ccx_to_2_and_1q_gates(): + + # use Nielsen and Chuang decomposition of CCX/Toffoli + # into 2 and 1 qubit gates. This is intentionally + # more complex than necessary just to test things out. + + @qasm2.main + def layer_1(ctrl1, ctrl2, target): + qasm2.h(target) + qasm2.cx(ctrl2, target) + qasm2.tdg(target) + qasm2.cx(ctrl1, target) + return + + @qasm2.main + def layer_2(ctrl1, ctrl2, target): + qasm2.t(ctrl1) + qasm2.cx(ctrl2, target) + qasm2.tdg(target) + qasm2.cx(ctrl1, target) + + @qasm2.main + def layer_3(ctrl1, ctrl2, target): + qasm2.t(ctrl2) + qasm2.t(target) + qasm2.cx(ctrl1, ctrl2) + qasm2.h(target) + qasm2.t(ctrl1) + qasm2.tdg(ctrl2) + qasm2.cx(ctrl1, ctrl2) + return + + @qasm2.extended + def lossy_toffoli_gate(ctrl1, ctrl2, target): + + layer_1(ctrl1, ctrl2, target) + noise.atom_loss_channel(qargs=[ctrl1, ctrl2, target], prob=0.01) + layer_2(ctrl1, ctrl2, target) + noise.atom_loss_channel(qargs=[ctrl1, ctrl2, target], prob=0.01) + layer_3(ctrl1, ctrl2, target) + noise.atom_loss_channel(qargs=[ctrl1, ctrl2, target], prob=0.01) + return + + @qasm2.extended + def rotation_layer(qs): + + glob.u(qs, math.pi, math.pi / 2, math.pi / 16) + parallel.rz(qs, math.pi / 4) + parallel.u(qs, math.pi / 2, math.pi / 2, math.pi / 8) + return + + @qasm2.main + def main(): + qs = qasm2.qreg(3) + # set up the control qubits + qasm2.x(qs[0]) + qasm2.x(qs[1]) + lossy_toffoli_gate(qs[0], qs[1], qs[2]) + rotation_layer(qs) + + QASM2ToSquin(dialects=main.dialects)(main) + AggressiveUnroll(dialects=main.dialects).fixpoint(main) + + actual_stmts = [ + type(stmt) + for stmt in main.callable_region.walk() + if isinstance(stmt, squin.gate.stmts.Gate) + or isinstance(stmt, squin.noise.stmts.NoiseChannel) + ] + expected_stmts = [ + squin.gate.stmts.X, + squin.gate.stmts.X, + # go into the toffoli + ## layer 1 + squin.gate.stmts.H, + squin.gate.stmts.CX, + squin.gate.stmts.T, # adjoint=True + squin.gate.stmts.CX, + ### atom loss + squin.noise.stmts.QubitLoss, + ## layer 2 + squin.gate.stmts.T, + squin.gate.stmts.CX, + squin.gate.stmts.T, # adjoint=True + squin.gate.stmts.CX, + ### atom loss + squin.noise.stmts.QubitLoss, + ## layer 3 + squin.gate.stmts.T, + squin.gate.stmts.T, + squin.gate.stmts.CX, + squin.gate.stmts.H, + squin.gate.stmts.T, + squin.gate.stmts.T, # adjoint=True + squin.gate.stmts.CX, + ### atom loss + squin.noise.stmts.QubitLoss, + # random rotation layer + squin.gate.stmts.U3, + squin.gate.stmts.Rz, + squin.gate.stmts.U3, + ] + + assert actual_stmts == expected_stmts + + num_T_adj = 0 + for stmt in main.callable_region.walk(): + if isinstance(stmt, squin.gate.stmts.T) and stmt.adjoint: + num_T_adj += 1 + + assert num_T_adj == 3