diff --git a/src/bloqade/analysis/address/impls.py b/src/bloqade/analysis/address/impls.py index ee5d8414..136bcab6 100644 --- a/src/bloqade/analysis/address/impls.py +++ b/src/bloqade/analysis/address/impls.py @@ -206,6 +206,20 @@ def apply( ): return frame.get_values(stmt.inputs) + @interp.impl(squin.wire.MeasureAndReset) + def measure_and_reset( + self, + interp_: AddressAnalysis, + frame: ForwardFrame[Address], + stmt: squin.wire.MeasureAndReset, + ): + + # take the address data from the incoming wire + # and propagate that forward to the new wire generated. + # The first entry can safely be NotQubit because + # it's an integer + return (NotQubit(), frame.get(stmt.wire)) + @squin.qubit.dialect.register(key="qubit.address") class SquinQubitMethodTable(interp.MethodTable): diff --git a/src/bloqade/squin/analysis/nsites/__init__.py b/src/bloqade/squin/analysis/nsites/__init__.py index da0a8e86..e3177322 100644 --- a/src/bloqade/squin/analysis/nsites/__init__.py +++ b/src/bloqade/squin/analysis/nsites/__init__.py @@ -1,6 +1,7 @@ # Need this for impl registration to work properly! from . import impls as impls from .lattice import ( + Sites as Sites, NoSites as NoSites, AnySites as AnySites, NumberSites as NumberSites, diff --git a/src/bloqade/squin/analysis/nsites/impls.py b/src/bloqade/squin/analysis/nsites/impls.py index 74b6b759..e8089bc9 100644 --- a/src/bloqade/squin/analysis/nsites/impls.py +++ b/src/bloqade/squin/analysis/nsites/impls.py @@ -1,6 +1,6 @@ from kirin import interp -from bloqade.squin import op +from bloqade.squin import op, wire from .lattice import ( NoSites, @@ -9,6 +9,30 @@ from .analysis import NSitesAnalysis +@wire.dialect.register(key="op.nsites") +class SquinWire(interp.MethodTable): + + @interp.impl(wire.Apply) + @interp.impl(wire.Broadcast) + def apply( + self, + interp: NSitesAnalysis, + frame: interp.Frame, + stmt: wire.Apply | wire.Broadcast, + ): + + return tuple(frame.get(input) for input in stmt.inputs) + + @interp.impl(wire.MeasureAndReset) + def measure_and_reset( + self, interp: NSitesAnalysis, frame: interp.Frame, stmt: wire.MeasureAndReset + ): + + # MeasureAndReset produces both a new wire + # and an integer which don't have any sites at all + return (NoSites(), NoSites()) + + @op.dialect.register(key="op.nsites") class SquinOp(interp.MethodTable): diff --git a/src/bloqade/squin/passes/__init__.py b/src/bloqade/squin/passes/__init__.py new file mode 100644 index 00000000..6368db40 --- /dev/null +++ b/src/bloqade/squin/passes/__init__.py @@ -0,0 +1 @@ +from .stim import SquinToStim as SquinToStim diff --git a/src/bloqade/squin/passes/stim.py b/src/bloqade/squin/passes/stim.py new file mode 100644 index 00000000..d919173c --- /dev/null +++ b/src/bloqade/squin/passes/stim.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass + +from kirin.passes import Fold +from kirin.rewrite import ( + Walk, + Chain, + Fixpoint, + DeadCodeElimination, + CommonSubexpressionElimination, +) +from kirin.ir.method import Method +from kirin.passes.abc import Pass +from kirin.rewrite.abc import RewriteResult + +from bloqade.squin.rewrite import ( + SquinWireToStim, + SquinQubitToStim, + WrapSquinAnalysis, + SquinMeasureToStim, + SquinWireIdentityElimination, +) +from bloqade.analysis.address import AddressAnalysis +from bloqade.squin.analysis.nsites import ( + NSitesAnalysis, +) + + +@dataclass +class SquinToStim(Pass): + + def unsafe_run(self, mt: Method) -> RewriteResult: + fold_pass = Fold(mt.dialects) + # propagate constants + rewrite_result = fold_pass(mt) + + # Get necessary analysis results to plug into hints + address_analysis = AddressAnalysis(mt.dialects) + address_frame, _ = address_analysis.run_analysis(mt) + site_analysis = NSitesAnalysis(mt.dialects) + sites_frame, _ = site_analysis.run_analysis(mt) + + # Wrap Rewrite + SquinToStim can happen w/ standard walk + rewrite_result = ( + Walk( + Chain( + WrapSquinAnalysis( + address_analysis=address_frame.entries, + op_site_analysis=sites_frame.entries, + ), + SquinQubitToStim(), + SquinWireToStim(), + SquinMeasureToStim(), # reduce duplicated logic, can split out even more rules later + SquinWireIdentityElimination(), + ) + ) + .rewrite(mt.code) + .join(rewrite_result) + ) + + rewrite_result = ( + Fixpoint( + Walk(Chain(DeadCodeElimination(), CommonSubexpressionElimination())) + ) + .rewrite(mt.code) + .join(rewrite_result) + ) + + return rewrite_result diff --git a/src/bloqade/squin/rewrite/__init__.py b/src/bloqade/squin/rewrite/__init__.py index e69de29b..1280ef2c 100644 --- a/src/bloqade/squin/rewrite/__init__.py +++ b/src/bloqade/squin/rewrite/__init__.py @@ -0,0 +1,11 @@ +from .wire_to_stim import SquinWireToStim as SquinWireToStim +from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim +from .squin_measure import SquinMeasureToStim as SquinMeasureToStim +from .wrap_analysis import ( + SitesAttribute as SitesAttribute, + AddressAttribute as AddressAttribute, + WrapSquinAnalysis as WrapSquinAnalysis, +) +from .wire_identity_elimination import ( + SquinWireIdentityElimination as SquinWireIdentityElimination, +) diff --git a/src/bloqade/squin/rewrite/qubit_to_stim.py b/src/bloqade/squin/rewrite/qubit_to_stim.py new file mode 100644 index 00000000..6a414d34 --- /dev/null +++ b/src/bloqade/squin/rewrite/qubit_to_stim.py @@ -0,0 +1,84 @@ +from kirin import ir +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import stim +from bloqade.squin import op, qubit +from bloqade.squin.rewrite.wrap_analysis import AddressAttribute +from bloqade.squin.rewrite.stim_rewrite_util import ( + SQUIN_STIM_GATE_MAPPING, + rewrite_Control, + insert_qubit_idx_from_address, +) + + +class SquinQubitToStim(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + match node: + case qubit.Apply() | qubit.Broadcast(): + return self.rewrite_Apply_and_Broadcast(node) + case qubit.Reset(): + return self.rewrite_Reset(node) + case _: + return RewriteResult() + + def rewrite_Apply_and_Broadcast( + self, stmt: qubit.Apply | qubit.Broadcast + ) -> RewriteResult: + """ + Rewrite Apply and Broadcast nodes to their stim equivalent statements. + """ + + # this is an SSAValue, need it to be the actual operator + applied_op = stmt.operator.owner + assert isinstance(applied_op, op.stmts.Operator) + + if isinstance(applied_op, op.stmts.Control): + return rewrite_Control(stmt) + + # need to handle Control through separate means + # but we can handle X, Y, Z, H, and S here just fine + stim_1q_op = SQUIN_STIM_GATE_MAPPING.get(type(applied_op)) + if stim_1q_op is None: + return RewriteResult() + + address_attr = stmt.qubits.hints.get("address") + if address_attr is None: + return RewriteResult() + + assert isinstance(address_attr, AddressAttribute) + qubit_idx_ssas = insert_qubit_idx_from_address( + address=address_attr, stmt_to_insert_before=stmt + ) + + if qubit_idx_ssas is None: + return RewriteResult() + + stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) + stmt.replace_by(stim_1q_stmt) + + return RewriteResult(has_done_something=True) + + def rewrite_Reset(self, reset_stmt: qubit.Reset) -> RewriteResult: + qubit_ilist_ssa = reset_stmt.qubits + # qubits are in an ilist which makes up an AddressTuple + address_attr = qubit_ilist_ssa.hints.get("address") + if address_attr is None: + return RewriteResult() + + assert isinstance(address_attr, AddressAttribute) + qubit_idx_ssas = insert_qubit_idx_from_address( + address=address_attr, stmt_to_insert_before=reset_stmt + ) + + if qubit_idx_ssas is None: + return RewriteResult() + + stim_rz_stmt = stim.collapse.stmts.RZ(targets=qubit_idx_ssas) + reset_stmt.replace_by(stim_rz_stmt) + + return RewriteResult(has_done_something=True) + + +# put rewrites for measure statements in separate rule, then just have to dispatch diff --git a/src/bloqade/squin/rewrite/squin_measure.py b/src/bloqade/squin/rewrite/squin_measure.py new file mode 100644 index 00000000..ee397233 --- /dev/null +++ b/src/bloqade/squin/rewrite/squin_measure.py @@ -0,0 +1,98 @@ +# create rewrite rule name SquinMeasureToStim using kirin +from kirin import ir +from kirin.dialects import py +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import stim +from bloqade.squin import wire, qubit +from bloqade.squin.rewrite.wrap_analysis import AddressAttribute +from bloqade.squin.rewrite.stim_rewrite_util import ( + is_measure_result_used, + insert_qubit_idx_from_address, +) + + +class SquinMeasureToStim(RewriteRule): + """ + Rewrite squin measure-related statements to stim statements. + """ + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + match node: + case qubit.MeasureQubit() | qubit.MeasureQubitList() | wire.Measure(): + return self.rewrite_Measure(node) + case qubit.MeasureAndReset() | wire.MeasureAndReset(): + return self.rewrite_MeasureAndReset(node) + case _: + return RewriteResult() + + def rewrite_Measure( + self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure + ) -> RewriteResult: + if is_measure_result_used(measure_stmt): + return RewriteResult() + + qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt) + if qubit_idx_ssas is None: + return RewriteResult() + + prob_noise_stmt = py.constant.Constant(0.0) + stim_measure_stmt = stim.collapse.MZ( + p=prob_noise_stmt.result, + targets=qubit_idx_ssas, + ) + prob_noise_stmt.insert_before(measure_stmt) + measure_stmt.replace_by(stim_measure_stmt) + + return RewriteResult(has_done_something=True) + + def rewrite_MeasureAndReset( + self, meas_and_reset_stmt: qubit.MeasureAndReset | wire.MeasureAndReset + ) -> RewriteResult: + if not is_measure_result_used(meas_and_reset_stmt): + return RewriteResult() + + qubit_idx_ssas = self.get_qubit_idx_ssas(meas_and_reset_stmt) + + if qubit_idx_ssas is None: + return RewriteResult() + + error_p_stmt = py.Constant(0.0) + stim_mz_stmt = stim.collapse.MZ(targets=qubit_idx_ssas, p=error_p_stmt.result) + stim_rz_stmt = stim.collapse.RZ( + targets=qubit_idx_ssas, + ) + + error_p_stmt.insert_before(meas_and_reset_stmt) + stim_mz_stmt.insert_before(meas_and_reset_stmt) + meas_and_reset_stmt.replace_by(stim_rz_stmt) + + return RewriteResult(has_done_something=True) + + def get_qubit_idx_ssas( + self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure + ) -> tuple[ir.SSAValue, ...] | None: + """ + Extract the address attribute and insert qubit indices for the given measure statement. + """ + match measure_stmt: + case qubit.MeasureQubit(): + address_attr = measure_stmt.qubit.hints.get("address") + case qubit.MeasureQubitList(): + address_attr = measure_stmt.qubits.hints.get("address") + case wire.Measure(): + address_attr = measure_stmt.wire.hints.get("address") + case _: + return None + + if address_attr is None: + return None + + assert isinstance(address_attr, AddressAttribute) + + qubit_idx_ssas = insert_qubit_idx_from_address( + address=address_attr, stmt_to_insert_before=measure_stmt + ) + + return qubit_idx_ssas diff --git a/src/bloqade/squin/rewrite/stim_rewrite_util.py b/src/bloqade/squin/rewrite/stim_rewrite_util.py new file mode 100644 index 00000000..1148558a --- /dev/null +++ b/src/bloqade/squin/rewrite/stim_rewrite_util.py @@ -0,0 +1,158 @@ +from kirin import ir +from kirin.dialects import py +from kirin.rewrite.abc import RewriteResult + +from bloqade import stim +from bloqade.squin import op, wire, qubit +from bloqade.analysis.address import AddressWire, AddressQubit, AddressTuple +from bloqade.squin.rewrite.wrap_analysis import AddressAttribute + +SQUIN_STIM_GATE_MAPPING = { + op.stmts.X: stim.gate.X, + op.stmts.Y: stim.gate.Y, + op.stmts.Z: stim.gate.Z, + op.stmts.H: stim.gate.H, + op.stmts.S: stim.gate.S, + op.stmts.Identity: stim.gate.Identity, +} + + +def insert_qubit_idx_from_address( + address: AddressAttribute, stmt_to_insert_before: ir.Statement +) -> tuple[ir.SSAValue, ...] | None: + """ + Extract qubit indices from an AddressAttribute and insert them into the SSA form. + """ + address_data = address.address + qubit_idx_ssas = [] + + if isinstance(address_data, AddressTuple): + for address_qubit in address_data.data: + if not isinstance(address_qubit, AddressQubit): + return + qubit_idx = address_qubit.data + qubit_idx_stmt = py.Constant(qubit_idx) + qubit_idx_stmt.insert_before(stmt_to_insert_before) + qubit_idx_ssas.append(qubit_idx_stmt.result) + elif isinstance(address_data, AddressWire): + address_qubit = address_data.origin_qubit + qubit_idx = address_qubit.data + qubit_idx_stmt = py.Constant(qubit_idx) + qubit_idx_stmt.insert_before(stmt_to_insert_before) + qubit_idx_ssas.append(qubit_idx_stmt.result) + else: + return + + return tuple(qubit_idx_ssas) + + +def insert_qubit_idx_from_wire_ssa( + wire_ssas: tuple[ir.SSAValue, ...], stmt_to_insert_before: ir.Statement +) -> tuple[ir.SSAValue, ...] | None: + """ + Extract qubit indices from wire SSA values and insert them into the SSA form. + """ + qubit_idx_ssas = [] + for wire_ssa in wire_ssas: + address_attribute = wire_ssa.hints.get("address") + if address_attribute is None: + return + assert isinstance(address_attribute, AddressAttribute) + wire_address = address_attribute.address + assert isinstance(wire_address, AddressWire) + qubit_idx = wire_address.origin_qubit.data + qubit_idx_stmt = py.Constant(qubit_idx) + qubit_idx_ssas.append(qubit_idx_stmt.result) + qubit_idx_stmt.insert_before(stmt_to_insert_before) + + return tuple(qubit_idx_ssas) + + +def insert_qubit_idx_after_apply( + stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast, +) -> tuple[ir.SSAValue, ...] | None: + """ + Extract qubit indices from Apply or Broadcast statements. + """ + if isinstance(stmt, (qubit.Apply, qubit.Broadcast)): + qubits = stmt.qubits + address_attribute = qubits.hints.get("address") + if address_attribute is None: + return + assert isinstance(address_attribute, AddressAttribute) + return insert_qubit_idx_from_address( + address=address_attribute, stmt_to_insert_before=stmt + ) + elif isinstance(stmt, (wire.Apply, wire.Broadcast)): + wire_ssas = stmt.inputs + return insert_qubit_idx_from_wire_ssa( + wire_ssas=wire_ssas, stmt_to_insert_before=stmt + ) + + +def rewrite_Control( + stmt_with_ctrl: qubit.Apply | wire.Apply | qubit.Broadcast | wire.Broadcast, +) -> RewriteResult: + """ + Handle control gates for Apply and Broadcast statements. + """ + ctrl_op = stmt_with_ctrl.operator.owner + assert isinstance(ctrl_op, op.stmts.Control) + + ctrl_op_target_gate = ctrl_op.op.owner + assert isinstance(ctrl_op_target_gate, op.stmts.Operator) + + qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt_with_ctrl) + if qubit_idx_ssas is None: + return RewriteResult() + + # Separate control and target qubits + target_qubits = [] + ctrl_qubits = [] + for i in range(len(qubit_idx_ssas)): + if (i % 2) == 0: + ctrl_qubits.append(qubit_idx_ssas[i]) + else: + target_qubits.append(qubit_idx_ssas[i]) + + target_qubits = tuple(target_qubits) + ctrl_qubits = tuple(ctrl_qubits) + + supported_gate_mapping = { + op.stmts.X: stim.CX, + op.stmts.Y: stim.CY, + op.stmts.Z: stim.CZ, + } + + stim_gate = supported_gate_mapping.get(type(ctrl_op_target_gate)) + if stim_gate is None: + return RewriteResult() + + stim_stmt = stim_gate(controls=ctrl_qubits, targets=target_qubits) + + if isinstance(stmt_with_ctrl, (wire.Apply, wire.Broadcast)): + # have to "reroute" the input of these statements to directly plug in + # to subsequent statements, remove dependency on the current statement + for input_wire, output_wire in zip( + stmt_with_ctrl.inputs, stmt_with_ctrl.results + ): + output_wire.replace_by(input_wire) + + stmt_with_ctrl.replace_by(stim_stmt) + + return RewriteResult(has_done_something=True) + + +def is_measure_result_used( + stmt: ( + qubit.MeasureAndReset + | qubit.MeasureQubit + | qubit.MeasureQubitList + | wire.MeasureAndReset + | wire.Measure + ), +) -> bool: + """ + Check if the result of a measure statement is used in the program. + """ + return bool(stmt.result.uses) diff --git a/src/bloqade/squin/rewrite/wire_identity_elimination.py b/src/bloqade/squin/rewrite/wire_identity_elimination.py new file mode 100644 index 00000000..a9dcc837 --- /dev/null +++ b/src/bloqade/squin/rewrite/wire_identity_elimination.py @@ -0,0 +1,24 @@ +from kirin import ir +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade.squin import wire + + +class SquinWireIdentityElimination(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + """ + Handle the case where an unwrap feeds a wire directly into a wrap, + equivalent to nothing happening/identity operation + + w = unwrap(qubit) + wrap(qubit, w) + """ + if isinstance(node, wire.Wrap): + wire_origin_stmt = node.wire.owner + if isinstance(wire_origin_stmt, wire.Unwrap): + node.delete() # get rid of wrap + wire_origin_stmt.delete() # get rid of the unwrap + return RewriteResult(has_done_something=True) + + return RewriteResult() diff --git a/src/bloqade/squin/rewrite/wire_to_stim.py b/src/bloqade/squin/rewrite/wire_to_stim.py new file mode 100644 index 00000000..82971c86 --- /dev/null +++ b/src/bloqade/squin/rewrite/wire_to_stim.py @@ -0,0 +1,73 @@ +from kirin import ir +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import stim +from bloqade.squin import op, wire +from bloqade.squin.rewrite.wrap_analysis import AddressAttribute +from bloqade.squin.rewrite.stim_rewrite_util import ( + SQUIN_STIM_GATE_MAPPING, + rewrite_Control, + insert_qubit_idx_from_address, + insert_qubit_idx_from_wire_ssa, +) + + +class SquinWireToStim(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + match node: + case wire.Apply() | wire.Broadcast(): + return self.rewrite_Apply_and_Broadcast(node) + case wire.Reset(): + return self.rewrite_Reset(node) + case _: + return RewriteResult() + + def rewrite_Apply_and_Broadcast( + self, stmt: wire.Apply | wire.Broadcast + ) -> RewriteResult: + + # this is an SSAValue, need it to be the actual operator + applied_op = stmt.operator.owner + assert isinstance(applied_op, op.stmts.Operator) + + if isinstance(applied_op, op.stmts.Control): + return rewrite_Control(stmt) + + stim_1q_op = SQUIN_STIM_GATE_MAPPING.get(type(applied_op)) + if stim_1q_op is None: + return RewriteResult() + + qubit_idx_ssas = insert_qubit_idx_from_wire_ssa( + wire_ssas=stmt.inputs, stmt_to_insert_before=stmt + ) + if qubit_idx_ssas is None: + return RewriteResult() + + stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) + + # Get the wires from the inputs of Apply or Broadcast, + # then put those as the result of the current stmt + # before replacing it entirely + for input_wire, output_wire in zip(stmt.inputs, stmt.results): + output_wire.replace_by(input_wire) + + stmt.replace_by(stim_1q_stmt) + + return RewriteResult(has_done_something=True) + + def rewrite_Reset(self, reset_stmt: wire.Reset) -> RewriteResult: + address_attr = reset_stmt.wire.hints.get("address") + if address_attr is None: + return RewriteResult() + assert isinstance(address_attr, AddressAttribute) + qubit_idx_ssas = insert_qubit_idx_from_address( + address=address_attr, stmt_to_insert_before=reset_stmt + ) + if qubit_idx_ssas is None: + return RewriteResult() + + stim_rz_stmt = stim.collapse.stmts.RZ(targets=qubit_idx_ssas) + reset_stmt.replace_by(stim_rz_stmt) + + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/wrap_analysis.py b/src/bloqade/squin/rewrite/wrap_analysis.py new file mode 100644 index 00000000..f4f47c0a --- /dev/null +++ b/src/bloqade/squin/rewrite/wrap_analysis.py @@ -0,0 +1,72 @@ +from dataclasses import dataclass + +from kirin import ir +from kirin.rewrite.abc import RewriteRule, RewriteResult +from kirin.print.printer import Printer + +from bloqade.squin import op, wire +from bloqade.analysis.address import Address +from bloqade.squin.analysis.nsites import Sites + + +@wire.dialect.register +@dataclass +class AddressAttribute(ir.Attribute): + + name = "Address" + address: Address + + def __hash__(self) -> int: + return hash(self.address) + + def print_impl(self, printer: Printer) -> None: + # Can return to implementing this later + printer.print(self.address) + + +@op.dialect.register +@dataclass +class SitesAttribute(ir.Attribute): + + name = "Sites" + sites: Sites + + def __hash__(self) -> int: + return hash(self.sites) + + def print_impl(self, printer: Printer) -> None: + # Can return to implementing this later + printer.print(self.sites) + + +@dataclass +class WrapSquinAnalysis(RewriteRule): + + address_analysis: dict[ir.SSAValue, Address] + op_site_analysis: dict[ir.SSAValue, Sites] + + def wrap(self, value: ir.SSAValue) -> bool: + address_analysis_result = self.address_analysis[value] + op_site_analysis_result = self.op_site_analysis[value] + + if value.hints.get("address") and value.hints.get("sites"): + return False + else: + value.hints["address"] = AddressAttribute(address_analysis_result) + value.hints["sites"] = SitesAttribute(op_site_analysis_result) + + return True + + def rewrite_Block(self, node: ir.Block) -> RewriteResult: + has_done_something = False + for arg in node.args: + if self.wrap(arg): + has_done_something = True + return RewriteResult(has_done_something=has_done_something) + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + has_done_something = False + for result in node.results: + if self.wrap(result): + has_done_something = True + return RewriteResult(has_done_something=has_done_something) diff --git a/test/squin/stim/stim.py b/test/squin/stim/stim.py new file mode 100644 index 00000000..533963b1 --- /dev/null +++ b/test/squin/stim/stim.py @@ -0,0 +1,547 @@ +from kirin import ir, types +from kirin.passes import Fold +from kirin.dialects import py, func, ilist + +import bloqade.squin.passes as squin_passes +from bloqade import stim, qasm2, squin +from bloqade.analysis import address +from bloqade.stim.emit import EmitStimMain + + +# Taken gratuitously from Kai's unit test +def stim_codegen(mt: ir.Method): + # method should not have any arguments! + emit = EmitStimMain(mt.dialects) + emit.initialize() + emit.run(mt=mt, args=()) + return emit.get_output() + + +def as_int(value: int): + return py.constant.Constant(value=value) + + +def as_float(value: float): + return py.constant.Constant(value=value) + + +def gen_func_from_stmts(stmts, output_type=types.NoneType): + + extended_dialect = ( + squin.groups.wired.add(qasm2.core) + .add(ilist) + .add(squin.qubit) + .add(stim.collapse) + .add(stim.gate) + ) + + block = ir.Block(stmts) + block.args.append_from(types.MethodType[[], types.NoneType], "main") + func_wrapper = func.Function( + sym_name="main", + signature=func.Signature(inputs=(), output=output_type), + body=ir.Region(blocks=block), + ) + + constructed_method = ir.Method( + mod=None, + py_func=None, + sym_name="main", + dialects=extended_dialect, + code=func_wrapper, + arg_names=[], + ) + + return constructed_method + + +def test_qubit_to_stim(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(4)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + (idx1 := as_int(1)), + (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)), + (idx2 := as_int(2)), + (q2 := qasm2.core.QRegGet(reg=qreg.result, idx=idx2.result)), + (idx3 := as_int(3)), + (q3 := qasm2.core.QRegGet(reg=qreg.result, idx=idx3.result)), + # create ilist of qubits + (q_list := ilist.New(values=(q0.result, q1.result, q2.result, q3.result))), + # Broadcast with stim semantics + (h_op := squin.op.stmts.H()), + (app_res := squin.qubit.Broadcast(h_op.result, q_list.result)), # noqa: F841 + # try Apply now + (x_op := squin.op.stmts.X()), + (sub_q_list := ilist.New(values=(q0.result,))), + (squin.qubit.Apply(x_op.result, sub_q_list.result)), + # go for a control gate + (ctrl_op := squin.op.stmts.Control(x_op.result, n_controls=1)), + (sub_q_list2 := ilist.New(values=(q1.result, q3.result))), + (squin.qubit.Apply(ctrl_op.result, sub_q_list2.result)), + # Measure everything out + (meas_res := squin.qubit.MeasureQubitList(q_list.result)), # noqa: F841 + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + constructed_method.print() + + squin_passes.SquinToStim(constructed_method.dialects, no_raise=False)( + constructed_method + ) + + constructed_method.print() + + # some problem with stim codegen in terms of + # stim_prog_str = stim_codegen(constructed_method) + # print(stim_prog_str) + + +def test_wire_to_stim(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(4)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + (idx1 := as_int(1)), + (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)), + (idx2 := as_int(2)), + (q2 := qasm2.core.QRegGet(reg=qreg.result, idx=idx2.result)), + (idx3 := as_int(3)), + (q3 := qasm2.core.QRegGet(reg=qreg.result, idx=idx3.result)), + # get wires from qubits + (w0 := squin.wire.Unwrap(qubit=q0.result)), + (w1 := squin.wire.Unwrap(qubit=q1.result)), + (w2 := squin.wire.Unwrap(qubit=q2.result)), + (w3 := squin.wire.Unwrap(qubit=q3.result)), + # try Apply + (op0 := squin.op.stmts.S()), + (app0 := squin.wire.Apply(op0.result, w0.result)), + # try Broadcast + (op1 := squin.op.stmts.H()), + ( + broad0 := squin.wire.Broadcast( + op1.result, app0.results[0], w1.result, w2.result, w3.result + ) + ), + # wrap everything back + (squin.wire.Wrap(broad0.results[0], q0.result)), + (squin.wire.Wrap(broad0.results[1], q1.result)), + (squin.wire.Wrap(broad0.results[2], q2.result)), + (squin.wire.Wrap(broad0.results[3], q3.result)), + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + squin_to_stim(constructed_method) + + constructed_method.print() + + +test_wire_to_stim() + + +def test_wire_1q_singular_apply(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(1)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubit out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + # Unwrap to get wires + (w0 := squin.wire.Unwrap(qubit=q0.result)), + # pass the wires through some 1 Qubit operators + (op1 := squin.op.stmts.S()), + (v0 := squin.wire.Apply(op1.result, w0.result)), + ( + squin.wire.Wrap(v0.results[0], q0.result) + ), # for wrap, just free a use for the result SSAval + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + # the fact I return a wire here means DCE will NOT go ahead and + # eliminate all the other wire.Apply stmts + ] + + constructed_method = gen_func_from_stmts(stmts) + + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + squin_to_stim(constructed_method) + + constructed_method.print() + + +def test_wire_1q(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(1)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubit out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + # Unwrap to get wires + (w0 := squin.wire.Unwrap(qubit=q0.result)), + # pass the wires through some 1 Qubit operators + (op1 := squin.op.stmts.S()), + (op2 := squin.op.stmts.H()), + (op3 := squin.op.stmts.Identity(sites=1)), + (op4 := squin.op.stmts.Identity(sites=1)), + (v0 := squin.wire.Apply(op1.result, w0.result)), + (v1 := squin.wire.Apply(op2.result, v0.results[0])), + (v2 := squin.wire.Apply(op3.result, v1.results[0])), + (v3 := squin.wire.Apply(op4.result, v2.results[0])), + ( + squin.wire.Wrap(v3.results[0], q0.result) + ), # for wrap, just free a use for the result SSAval + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + # the fact I return a wire here means DCE will NOT go ahead and + # eliminate all the other wire.Apply stmts + ] + + constructed_method = gen_func_from_stmts(stmts) + + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + squin_to_stim(constructed_method) + + constructed_method.print() + + +def test_broadcast_wire_1q_application(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(4)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + (idx1 := as_int(1)), + (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)), + (idx2 := as_int(2)), + (q2 := qasm2.core.QRegGet(reg=qreg.result, idx=idx2.result)), + (idx3 := as_int(3)), + (q3 := qasm2.core.QRegGet(reg=qreg.result, idx=idx3.result)), + # Unwrap to get wires + (w0 := squin.wire.Unwrap(qubit=q0.result)), + (w1 := squin.wire.Unwrap(qubit=q1.result)), + (w2 := squin.wire.Unwrap(qubit=q2.result)), + (w3 := squin.wire.Unwrap(qubit=q3.result)), + # Apply with stim semantics + (h_op := squin.op.stmts.H()), + ( + app_res := squin.wire.Broadcast( + h_op.result, w0.result, w1.result, w2.result, w3.result + ) + ), + # Wrap everything back + (squin.wire.Wrap(app_res.results[0], q0.result)), + (squin.wire.Wrap(app_res.results[1], q1.result)), + (squin.wire.Wrap(app_res.results[2], q2.result)), + (squin.wire.Wrap(app_res.results[3], q3.result)), + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + squin_to_stim(constructed_method) + + constructed_method.print() + + +# before ANY rewrite, aggressively inline everything, then do the rewrite +# for Stim pass, need to call validation , check any invoke + +# Put one codegen test to stim +# finish measurement analysis Friday - if painful, ask help from Kai +# work on other detector rewrite + +# later on lower for loop to repeat + + +def test_broadcast_qubit_1q_application(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(4)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + (idx1 := as_int(1)), + (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)), + (idx2 := as_int(2)), + (q2 := qasm2.core.QRegGet(reg=qreg.result, idx=idx2.result)), + (idx3 := as_int(3)), + (q3 := qasm2.core.QRegGet(reg=qreg.result, idx=idx3.result)), + # create ilist of qubits + (q_list := ilist.New(values=(q0.result, q1.result, q2.result, q3.result))), + # Apply with stim semantics + (h_op := squin.op.stmts.H()), + (app_res := squin.qubit.Broadcast(h_op.result, q_list.result)), # noqa: F841 + # Measure everything out + (meas_res := squin.qubit.MeasureQubitList(q_list.result)), # noqa: F841 + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + squin_to_stim(constructed_method) + + constructed_method.print() + + +def test_broadcast_control_gate_wire_application(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(4)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + (idx1 := as_int(1)), + (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)), + (idx2 := as_int(2)), + (q2 := qasm2.core.QRegGet(reg=qreg.result, idx=idx2.result)), + (idx3 := as_int(3)), + (q3 := qasm2.core.QRegGet(reg=qreg.result, idx=idx3.result)), + # Unwrap to get wires + (w0 := squin.wire.Unwrap(qubit=q0.result)), + (w1 := squin.wire.Unwrap(qubit=q1.result)), + (w2 := squin.wire.Unwrap(qubit=q2.result)), + (w3 := squin.wire.Unwrap(qubit=q3.result)), + # Create and apply CX gate + (x_op := squin.op.stmts.X()), + (ctrl_x_op := squin.op.stmts.Control(x_op.result, n_controls=1)), + ( + app_res := squin.wire.Broadcast( + ctrl_x_op.result, w0.result, w1.result, w2.result, w3.result + ) + ), + # measure it all out + (meas_res_0 := squin.wire.Measure(app_res.results[0])), # noqa: F841 + (meas_res_1 := squin.wire.Measure(app_res.results[1])), # noqa: F841 + (meas_res_2 := squin.wire.Measure(app_res.results[2])), # noqa: F841 + (meas_res_3 := squin.wire.Measure(app_res.results[3])), # noqa: F841 + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + squin_to_stim(constructed_method) + + constructed_method.print() + + +def test_wire_control(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(2)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubis out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + (idx1 := as_int(1)), + (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)), + # Unwrap to get wires + (w0 := squin.wire.Unwrap(qubit=q0.result)), + (w1 := squin.wire.Unwrap(qubit=q1.result)), + # set up control gate + (op1 := squin.op.stmts.X()), + (cx := squin.op.stmts.Control(op1.result, n_controls=1)), + (app := squin.wire.Apply(cx.result, w0.result, w1.result)), + # wrap things back + (squin.wire.Wrap(wire=app.results[0], qubit=q0.result)), + (squin.wire.Wrap(wire=app.results[1], qubit=q1.result)), + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + squin_to_stim(constructed_method) + + constructed_method.print() + + +# Measure being depended on, internal replace_by call +# will not be happy but assumption with rewrite is the +# program is in a valid form +def test_wire_measure(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(2)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubis out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + # Unwrap to get wires + (w0 := squin.wire.Unwrap(qubit=q0.result)), + # measure the wires out + (r0 := squin.wire.Measure(w0.result)), + # return ints so DCE doesn't get + # rid of everything + # (ret_none := func.ConstantNone()), + (func.Return(r0)), + ] + + constructed_method = gen_func_from_stmts(stmts) + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + rewrite_result = squin_to_stim(constructed_method) + print(rewrite_result) + constructed_method.print() + + +def test_qubit_reset(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(1)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + # qubit.reset only accepts ilist of qubits + (qlist := ilist.New(values=[q0.result])), + (squin.qubit.Reset(qubits=qlist.result)), + # (squin.qubit.Measure(qubits=qlist.result)), + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + rewrite_result = squin_to_stim(constructed_method) + print(rewrite_result) + constructed_method.print() + + +def test_wire_reset(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(1)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + # get wire + (w0 := squin.wire.Unwrap(q0.result)), + # reset the wire + (squin.wire.Reset(w0.result)), + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + rewrite_result = squin_to_stim(constructed_method) + print(rewrite_result) + constructed_method.print() + + +def test_qubit_measure_and_reset(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(1)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + # qubit.reset only accepts ilist of qubits + (qlist := ilist.New(values=[q0.result])), + (squin.qubit.MeasureAndReset(qlist.result)), + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + constructed_method.print() + + # analysis_res, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis(constructed_method) + # constructed_method.print(analysis=analysis_res.entries) + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + rewrite_result = squin_to_stim(constructed_method) + print(rewrite_result) + constructed_method.print() + + +def test_wire_measure_and_reset(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(1)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + # get wire out + (w0 := squin.wire.Unwrap(q0.result)), + # qubit.reset only accepts ilist of qubits + (squin.wire.MeasureAndReset(w0.result)), + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + constructed_method.print() + + fold_pass = Fold(constructed_method.dialects) + fold_pass(constructed_method) + # need to make sure the origin qubit data is properly + # propagated to the new wire that wire.MeasureAndReset spits out + address_res, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis( + constructed_method + ) + constructed_method.print(analysis=address_res.entries) + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + rewrite_result = squin_to_stim(constructed_method) + print(rewrite_result) + constructed_method.print()