diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d71e1f9a..e57cf526 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,6 +7,7 @@ repos: - id: check-yaml args: ['--unsafe'] - id: end-of-file-fixer + exclude: .*\.stim$ - id: trailing-whitespace - repo: https://github.com/pycqa/isort rev: 6.0.1 diff --git a/src/bloqade/analysis/measure_id/impls.py b/src/bloqade/analysis/measure_id/impls.py index 993b97bd..05267fd3 100644 --- a/src/bloqade/analysis/measure_id/impls.py +++ b/src/bloqade/analysis/measure_id/impls.py @@ -2,7 +2,7 @@ from kirin.analysis import const from kirin.dialects import py, scf, func, ilist -from bloqade import qubit +from bloqade import qubit, annotate from .lattice import ( AnyMeasureId, @@ -46,6 +46,20 @@ def measure_qubit_list( return (MeasureIdTuple(data=tuple(measure_id_bools)),) +@annotate.dialect.register(key="measure_id") +class Annotate(interp.MethodTable): + @interp.impl(annotate.stmts.SetObservable) + @interp.impl(annotate.stmts.SetDetector) + def consumes_measurement_results( + self, + interp: MeasurementIDAnalysis, + frame: MeasureIDFrame, + stmt: annotate.stmts.SetObservable | annotate.stmts.SetDetector, + ): + frame.num_measures_at_stmt[stmt] = interp.measure_count + return (NotMeasureId(),) + + @ilist.dialect.register(key="measure_id") class IList(interp.MethodTable): @interp.impl(ilist.New) diff --git a/src/bloqade/annotate/__init__.py b/src/bloqade/annotate/__init__.py new file mode 100644 index 00000000..88efc714 --- /dev/null +++ b/src/bloqade/annotate/__init__.py @@ -0,0 +1,6 @@ +from . import stmts as stmts +from ._dialect import dialect as dialect +from ._interface import ( + set_detector as set_detector, + set_observable as set_observable, +) diff --git a/src/bloqade/annotate/_dialect.py b/src/bloqade/annotate/_dialect.py new file mode 100644 index 00000000..a5cbad13 --- /dev/null +++ b/src/bloqade/annotate/_dialect.py @@ -0,0 +1,3 @@ +from kirin import ir + +dialect = ir.Dialect("squin.annotate") diff --git a/src/bloqade/annotate/_interface.py b/src/bloqade/annotate/_interface.py new file mode 100644 index 00000000..f7dccc1b --- /dev/null +++ b/src/bloqade/annotate/_interface.py @@ -0,0 +1,22 @@ +from typing import Any + +from kirin.dialects import ilist +from kirin.lowering import wraps + +from bloqade.types import MeasurementResult + +from .stmts import SetDetector, SetObservable +from .types import Detector, Observable + + +@wraps(SetDetector) +def set_detector( + measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult], + coordinates: ilist.IList[float | int, Any] | list[float | int], +) -> Detector: ... + + +@wraps(SetObservable) +def set_observable( + measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult], +) -> Observable: ... diff --git a/src/bloqade/annotate/stmts.py b/src/bloqade/annotate/stmts.py new file mode 100644 index 00000000..789b85bb --- /dev/null +++ b/src/bloqade/annotate/stmts.py @@ -0,0 +1,29 @@ +from kirin import ir, types as kirin_types, lowering +from kirin.decl import info, statement +from kirin.dialects import ilist + +from bloqade.types import MeasurementResultType +from bloqade.annotate.types import DetectorType, ObservableType + +from ._dialect import dialect + + +@statement +class ConsumesMeasurementResults(ir.Statement): + traits = frozenset({lowering.FromPythonCall()}) + measurements: ir.SSAValue = info.argument( + ilist.IListType[MeasurementResultType, kirin_types.Any] + ) + + +@statement(dialect=dialect) +class SetDetector(ConsumesMeasurementResults): + coordinates: ir.SSAValue = info.argument( + type=ilist.IListType[kirin_types.Int | kirin_types.Float, kirin_types.Any] + ) + result: ir.ResultValue = info.result(DetectorType) + + +@statement(dialect=dialect) +class SetObservable(ConsumesMeasurementResults): + result: ir.ResultValue = info.result(ObservableType) diff --git a/src/bloqade/annotate/types.py b/src/bloqade/annotate/types.py new file mode 100644 index 00000000..2783aa8c --- /dev/null +++ b/src/bloqade/annotate/types.py @@ -0,0 +1,13 @@ +from kirin import types + + +class Detector: + pass + + +class Observable: + pass + + +DetectorType = types.PyClass(Detector) +ObservableType = types.PyClass(Observable) diff --git a/src/bloqade/squin/__init__.py b/src/bloqade/squin/__init__.py index 3b4e8c88..3a91751d 100644 --- a/src/bloqade/squin/__init__.py +++ b/src/bloqade/squin/__init__.py @@ -3,7 +3,7 @@ noise as noise, analysis as analysis, ) -from .. import qubit as qubit +from .. import qubit as qubit, annotate as annotate from ..qubit import ( reset as reset, qalloc as qalloc, @@ -12,6 +12,7 @@ get_measurement_id as get_measurement_id, ) from .groups import kernel as kernel +from ..annotate import set_detector as set_detector, set_observable as set_observable from .stdlib.simple import ( h as h, s as s, diff --git a/src/bloqade/squin/groups.py b/src/bloqade/squin/groups.py index c23f78fe..a29a2446 100644 --- a/src/bloqade/squin/groups.py +++ b/src/bloqade/squin/groups.py @@ -3,10 +3,10 @@ from kirin.dialects import debug, ilist from . import gate, noise -from .. import qubit +from .. import qubit, annotate -@ir.dialect_group(structural_no_opt.union([qubit, noise, gate, debug])) +@ir.dialect_group(structural_no_opt.union([qubit, noise, gate, debug, annotate])) def kernel(self): fold_pass = passes.Fold(self) typeinfer_pass = passes.TypeInfer(self) diff --git a/src/bloqade/stim/passes/squin_to_stim.py b/src/bloqade/stim/passes/squin_to_stim.py index bf73b2c7..334bbde4 100644 --- a/src/bloqade/stim/passes/squin_to_stim.py +++ b/src/bloqade/stim/passes/squin_to_stim.py @@ -27,7 +27,7 @@ from bloqade.analysis.measure_id import MeasurementIDAnalysis from bloqade.stim.passes.flatten import Flatten -from ..rewrite.ifs_to_stim import IfToStim +from ..rewrite import IfToStim, SetDetectorToStim, SetObservableToStim @dataclass @@ -64,6 +64,8 @@ def unsafe_run(self, mt: Method) -> RewriteResult: rewrite_result = ( Chain( Walk(IfToStim(measure_frame=meas_analysis_frame)), + Walk(SetDetectorToStim(measure_id_frame=meas_analysis_frame)), + Walk(SetObservableToStim(measure_id_frame=meas_analysis_frame)), Fixpoint(Walk(DeadCodeElimination())), ) .rewrite(mt.code) diff --git a/src/bloqade/stim/rewrite/__init__.py b/src/bloqade/stim/rewrite/__init__.py index 4b0eb8fe..6b04bdc2 100644 --- a/src/bloqade/stim/rewrite/__init__.py +++ b/src/bloqade/stim/rewrite/__init__.py @@ -3,3 +3,5 @@ from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim from .squin_measure import SquinMeasureToStim as SquinMeasureToStim from .py_constant_to_stim import PyConstantToStim as PyConstantToStim +from .set_detector_to_stim import SetDetectorToStim as SetDetectorToStim +from .set_observable_to_stim import SetObservableToStim as SetObservableToStim diff --git a/src/bloqade/stim/rewrite/get_record_util.py b/src/bloqade/stim/rewrite/get_record_util.py new file mode 100644 index 00000000..aaa28261 --- /dev/null +++ b/src/bloqade/stim/rewrite/get_record_util.py @@ -0,0 +1,24 @@ +from kirin import ir +from kirin.dialects import py + +from bloqade.stim.dialects import auxiliary +from bloqade.analysis.measure_id.lattice import MeasureIdBool, MeasureIdTuple + + +def insert_get_records( + node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count_at_stmt: int +): + """ + Insert GetRecord statements before the given node + """ + get_record_ssas = [] + for measure_id_bool in measure_id_tuple.data: + assert isinstance(measure_id_bool, MeasureIdBool) + target_rec_idx = (measure_id_bool.idx - 1) - meas_count_at_stmt + idx_stmt = py.constant.Constant(target_rec_idx) + idx_stmt.insert_before(node) + get_record_stmt = auxiliary.GetRecord(idx_stmt.result) + get_record_stmt.insert_before(node) + get_record_ssas.append(get_record_stmt.result) + + return get_record_ssas diff --git a/src/bloqade/stim/rewrite/set_detector_to_stim.py b/src/bloqade/stim/rewrite/set_detector_to_stim.py new file mode 100644 index 00000000..229067a2 --- /dev/null +++ b/src/bloqade/stim/rewrite/set_detector_to_stim.py @@ -0,0 +1,68 @@ +from typing import Iterable +from dataclasses import dataclass + +from kirin import ir +from kirin.dialects.py import Constant +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade.stim.dialects import auxiliary +from bloqade.annotate.stmts import SetDetector +from bloqade.analysis.measure_id import MeasureIDFrame +from bloqade.stim.dialects.auxiliary import Detector +from bloqade.analysis.measure_id.lattice import MeasureIdTuple + +from ..rewrite.get_record_util import insert_get_records + + +@dataclass +class SetDetectorToStim(RewriteRule): + """ + Rewrite SetDetector to GetRecord and Detector in the stim dialect + """ + + measure_id_frame: MeasureIDFrame + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + match node: + case SetDetector(): + return self.rewrite_SetDetector(node) + case _: + return RewriteResult() + + def rewrite_SetDetector(self, node: SetDetector) -> RewriteResult: + + # get coordinates and generate correct consts + coord_ssas = [] + if not isinstance(node.coordinates.owner, Constant): + return RewriteResult() + + coord_values = node.coordinates.owner.value.unwrap() + + if not isinstance(coord_values, Iterable): + return RewriteResult() + + if any(not isinstance(value, (int, float)) for value in coord_values): + return RewriteResult() + + for coord_value in coord_values: + if isinstance(coord_value, float): + coord_stmt = auxiliary.ConstFloat(value=coord_value) + else: # int + coord_stmt = auxiliary.ConstInt(value=coord_value) + coord_ssas.append(coord_stmt.result) + coord_stmt.insert_before(node) + + measure_ids = self.measure_id_frame.entries[node.measurements] + assert isinstance(measure_ids, MeasureIdTuple) + + get_record_list = insert_get_records( + node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node] + ) + + detector_stmt = Detector( + coord=tuple(coord_ssas), targets=tuple(get_record_list) + ) + + node.replace_by(detector_stmt) + + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/stim/rewrite/set_observable_to_stim.py b/src/bloqade/stim/rewrite/set_observable_to_stim.py new file mode 100644 index 00000000..39ac14fe --- /dev/null +++ b/src/bloqade/stim/rewrite/set_observable_to_stim.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass + +from kirin import ir +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade.stim.dialects import auxiliary +from bloqade.annotate.stmts import SetObservable +from bloqade.analysis.measure_id import MeasureIDFrame +from bloqade.stim.dialects.auxiliary import ObservableInclude +from bloqade.analysis.measure_id.lattice import MeasureIdTuple + +from ..rewrite.get_record_util import insert_get_records + + +@dataclass +class SetObservableToStim(RewriteRule): + """ + Rewrite SetObservable to GetRecord and ObservableInclude in the stim dialect + """ + + measure_id_frame: MeasureIDFrame + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + match node: + case SetObservable(): + return self.rewrite_SetObservable(node) + case _: + return RewriteResult() + + def rewrite_SetObservable(self, node: SetObservable) -> RewriteResult: + + # set idx to 0 for now, but this + # should be something that a user can set on their own. + # SetObservable needs to accept an int. + + idx_stmt = auxiliary.ConstInt(value=0) + idx_stmt.insert_before(node) + + measure_ids = self.measure_id_frame.entries[node.measurements] + assert isinstance(measure_ids, MeasureIdTuple) + + get_record_list = insert_get_records( + node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node] + ) + + observable_include_stmt = ObservableInclude( + idx=idx_stmt.result, targets=tuple(get_record_list) + ) + + node.replace_by(observable_include_stmt) + + return RewriteResult(has_done_something=True) diff --git a/test/stim/passes/stim_reference_programs/annotate/broadcast_with_alias.stim b/test/stim/passes/stim_reference_programs/annotate/broadcast_with_alias.stim new file mode 100644 index 00000000..13098f64 --- /dev/null +++ b/test/stim/passes/stim_reference_programs/annotate/broadcast_with_alias.stim @@ -0,0 +1,2 @@ + +X 0 1 \ No newline at end of file diff --git a/test/stim/passes/stim_reference_programs/annotate/clean_surface_code.stim b/test/stim/passes/stim_reference_programs/annotate/clean_surface_code.stim new file mode 100644 index 00000000..9703ef7d --- /dev/null +++ b/test/stim/passes/stim_reference_programs/annotate/clean_surface_code.stim @@ -0,0 +1,38 @@ + +RZ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 +H 13 14 15 16 +CX 13 14 16 1 3 7 0 2 4 10 11 12 +CX 13 14 16 2 4 8 3 5 7 10 11 12 +CX 13 15 16 0 4 6 1 3 5 9 10 11 +CX 13 15 16 1 5 7 4 6 8 9 10 11 +H 13 14 15 16 +MZ(0.00000000) 13 14 15 16 +MZ(0.00000000) 9 10 11 12 +RZ 13 14 15 16 9 10 11 12 +DETECTOR(0, 0) rec[-4] +DETECTOR(0, 0) rec[-3] +DETECTOR(0, 0) rec[-2] +DETECTOR(0, 0) rec[-1] +H 13 14 15 16 +CX 13 14 16 1 3 7 0 2 4 10 11 12 +CX 13 14 16 2 4 8 3 5 7 10 11 12 +CX 13 15 16 0 4 6 1 3 5 9 10 11 +CX 13 15 16 1 5 7 4 6 8 9 10 11 +H 13 14 15 16 +MZ(0.00000000) 13 14 15 16 +MZ(0.00000000) 9 10 11 12 +RZ 13 14 15 16 9 10 11 12 +DETECTOR(0, 0) rec[-4] rec[-12] +DETECTOR(0, 0) rec[-3] rec[-11] +DETECTOR(0, 0) rec[-2] rec[-10] +DETECTOR(0, 0) rec[-1] rec[-9] +DETECTOR(0, 0) rec[-8] rec[-16] +DETECTOR(0, 0) rec[-7] rec[-15] +DETECTOR(0, 0) rec[-6] rec[-14] +DETECTOR(0, 0) rec[-5] rec[-13] +MZ(0.00000000) 0 1 2 3 4 5 6 7 8 +DETECTOR(0, 0) rec[-9] rec[-8] rec[-13] +DETECTOR(0, 0) rec[-8] rec[-7] rec[-5] rec[-4] rec[-12] +DETECTOR(0, 0) rec[-6] rec[-5] rec[-3] rec[-2] rec[-11] +DETECTOR(0, 0) rec[-2] rec[-1] rec[-10] +OBSERVABLE_INCLUDE(0) rec[-9] rec[-8] rec[-7] \ No newline at end of file diff --git a/test/stim/passes/stim_reference_programs/annotate/detector_coords_as_args.stim b/test/stim/passes/stim_reference_programs/annotate/detector_coords_as_args.stim new file mode 100644 index 00000000..f7d6290a --- /dev/null +++ b/test/stim/passes/stim_reference_programs/annotate/detector_coords_as_args.stim @@ -0,0 +1,9 @@ + +MZ(0.00000000) 0 1 +DETECTOR(0, 1) rec[-2] rec[-1] +DETECTOR(3, 4) rec[-2] rec[-1] +DETECTOR(0, 5.00000000) rec[-2] rec[-1] +DETECTOR(5.00000000, 3) rec[-2] rec[-1] +DETECTOR(1, 2, 5.00000000) rec[-2] rec[-1] +DETECTOR(1, 2) rec[-2] rec[-1] +DETECTOR(1, 2, 5.00000000) rec[-2] rec[-1] \ No newline at end of file diff --git a/test/stim/passes/stim_reference_programs/annotate/kernel_base_op.stim b/test/stim/passes/stim_reference_programs/annotate/kernel_base_op.stim new file mode 100644 index 00000000..7aca4024 --- /dev/null +++ b/test/stim/passes/stim_reference_programs/annotate/kernel_base_op.stim @@ -0,0 +1,6 @@ + +H 1 3 2 +MZ(0.00000000) 1 3 2 +DETECTOR(0, 0) rec[-3] +DETECTOR(0, 0) rec[-2] +DETECTOR(0, 0) rec[-1] \ No newline at end of file diff --git a/test/stim/passes/stim_reference_programs/annotate/linear_program_rewrite.stim b/test/stim/passes/stim_reference_programs/annotate/linear_program_rewrite.stim new file mode 100644 index 00000000..148b64b7 --- /dev/null +++ b/test/stim/passes/stim_reference_programs/annotate/linear_program_rewrite.stim @@ -0,0 +1,11 @@ + +X 0 +Y 1 +Z 2 +CX 0 1 +CX 0 1 2 3 +Z 0 1 2 3 +MZ(0.00000000) 0 1 2 3 +DETECTOR(0.00000000, 0.00000000) rec[-4] rec[-3] +DETECTOR(1.00000000, 0.00000000) rec[-3] rec[-2] +OBSERVABLE_INCLUDE(0) rec[-2] \ No newline at end of file diff --git a/test/stim/passes/stim_reference_programs/annotate/measure_desugar.stim b/test/stim/passes/stim_reference_programs/annotate/measure_desugar.stim new file mode 100644 index 00000000..9092c649 --- /dev/null +++ b/test/stim/passes/stim_reference_programs/annotate/measure_desugar.stim @@ -0,0 +1,6 @@ + +MZ(0.00000000) 0 +MZ(0.00000000) 0 +MZ(0.00000000) 0 +MZ(0.00000000) 0 +MZ(0.00000000) 0 \ No newline at end of file diff --git a/test/stim/passes/stim_reference_programs/annotate/nested_for.stim b/test/stim/passes/stim_reference_programs/annotate/nested_for.stim new file mode 100644 index 00000000..6697fd60 --- /dev/null +++ b/test/stim/passes/stim_reference_programs/annotate/nested_for.stim @@ -0,0 +1,7 @@ + +MZ(0.00000000) 0 1 +DETECTOR(0, 0) rec[-2] +DETECTOR(0, 0) rec[-1] +MZ(0.00000000) 0 1 +DETECTOR(0, 0) rec[-2] +DETECTOR(0, 0) rec[-1] \ No newline at end of file diff --git a/test/stim/passes/stim_reference_programs/annotate/no_kernel_base_op.stim b/test/stim/passes/stim_reference_programs/annotate/no_kernel_base_op.stim new file mode 100644 index 00000000..86a28575 --- /dev/null +++ b/test/stim/passes/stim_reference_programs/annotate/no_kernel_base_op.stim @@ -0,0 +1,6 @@ + +X 0 2 1 +MZ(0.00000000) 0 2 1 +DETECTOR(0, 0) rec[-3] +DETECTOR(0, 0) rec[-2] +DETECTOR(0, 0) rec[-1] \ No newline at end of file diff --git a/test/stim/passes/stim_reference_programs/annotate/rep_code.stim b/test/stim/passes/stim_reference_programs/annotate/rep_code.stim new file mode 100644 index 00000000..ce74b78f --- /dev/null +++ b/test/stim/passes/stim_reference_programs/annotate/rep_code.stim @@ -0,0 +1,23 @@ + +RZ 0 1 2 3 4 +CX 0 1 2 3 +CX 2 1 4 3 +MZ(0.00000000) 1 3 +DETECTOR(0, 0) rec[-2] +DETECTOR(0, 0) rec[-1] +CX 0 1 2 3 +CX 2 1 4 3 +MZ(0.00000000) 1 3 +DETECTOR(0, 0) rec[-4] rec[-2] +DETECTOR(0, 0) rec[-3] rec[-1] +CX 0 1 2 3 +CX 2 1 4 3 +DEPOLARIZE2(0.01000000) 0 1 2 3 +I_ERROR[loss](0.00100000) 0 1 2 3 4 +MZ(0.00000000) 1 3 +DETECTOR(0, 0) rec[-4] rec[-2] +DETECTOR(0, 0) rec[-3] rec[-1] +MZ(0.00000000) 0 2 4 +DETECTOR(0, 0) rec[-3] rec[-2] rec[-5] +DETECTOR(0, 0) rec[-2] rec[-1] rec[-4] +OBSERVABLE_INCLUDE(0) rec[-1] \ No newline at end of file diff --git a/test/stim/passes/stim_reference_programs/annotate/set_detector_with_alias.stim b/test/stim/passes/stim_reference_programs/annotate/set_detector_with_alias.stim new file mode 100644 index 00000000..372b8ad3 --- /dev/null +++ b/test/stim/passes/stim_reference_programs/annotate/set_detector_with_alias.stim @@ -0,0 +1,3 @@ + +MZ(0.00000000) 0 1 +DETECTOR(0, 0) rec[-2] rec[-1] \ No newline at end of file diff --git a/test/stim/passes/stim_reference_programs/annotate/simple_if_rewrite.stim b/test/stim/passes/stim_reference_programs/annotate/simple_if_rewrite.stim new file mode 100644 index 00000000..89aee87c --- /dev/null +++ b/test/stim/passes/stim_reference_programs/annotate/simple_if_rewrite.stim @@ -0,0 +1,10 @@ + +MZ(0.00000000) 0 1 2 3 +CZ rec[-4] 0 +CX rec[-4] 1 rec[-4] 2 rec[-4] 3 +CZ rec[-4] 0 rec[-4] 1 rec[-4] 2 rec[-4] 3 +CX rec[-3] 0 +CY rec[-3] 1 +MZ(0.00000000) 0 1 2 3 +DETECTOR(0.00000000, 0.00000000) rec[-4] rec[-3] +OBSERVABLE_INCLUDE(0) rec[-2] \ No newline at end of file diff --git a/test/stim/passes/test_annotation_to_stim.py b/test/stim/passes/test_annotation_to_stim.py new file mode 100644 index 00000000..5ab36d2f --- /dev/null +++ b/test/stim/passes/test_annotation_to_stim.py @@ -0,0 +1,336 @@ +import os + +from kirin import ir +from kirin.dialects import scf, ilist + +from bloqade import squin +from bloqade.stim.emit import EmitStimMain +from bloqade.stim.passes import SquinToStimPass + + +def codegen(mt: ir.Method): + # method should not have any arguments! + emit = EmitStimMain() + emit.initialize() + emit.run(mt=mt, args=()) + return emit.get_output() + + +def load_reference_program(filename): + path = os.path.join( + os.path.dirname(__file__), "stim_reference_programs", "annotate", filename + ) + with open(path, "r") as f: + return f.read() + + +def test_linear_program_rewrite(): + + @squin.kernel + def main(): + n_qubits = 4 + q = squin.qalloc(n_qubits) + + # do some gates + squin.x(q[0]) + squin.y(q[1]) + squin.z(q[2]) + squin.cx(q[0], q[1]) + # Broadcast control + squin.broadcast.cx(controls=[q[0], q[2]], targets=[q[1], q[3]]) + # broadcast single qubit gate + squin.broadcast.z(q) + + # measure everything out + ms = squin.broadcast.measure(q) + + # use some statements from dialect + squin.set_detector([ms[0], ms[1]], coordinates=[0.0, 0.0]) + squin.set_detector([ms[1], ms[2]], coordinates=[1.0, 0.0]) + + squin.set_observable(measurements=[ms[2]]) + + return + + SquinToStimPass(main.dialects)(main) + + base_stim_prog = load_reference_program("linear_program_rewrite.stim") + + assert base_stim_prog == codegen(main) + + +def test_simple_if_rewrite(): + + @squin.kernel + def main(): + n_qubits = 4 + q = squin.qalloc(n_qubits) + + ms = squin.broadcast.measure(q) + + if ms[0]: + squin.z(q[0]) + squin.broadcast.x([q[1], q[2], q[3]]) + squin.broadcast.z(q) + + if ms[1]: + squin.x(q[0]) + squin.y(q[1]) + + ms1 = squin.broadcast.measure(q) + squin.set_detector([ms1[0], ms1[1]], coordinates=[0.0, 0.0]) + squin.set_observable(measurements=[ms1[2]]) + + return + + SquinToStimPass(main.dialects)(main) + + base_stim_prog = load_reference_program("simple_if_rewrite.stim") + + assert base_stim_prog == codegen(main) + + +def test_if_with_else_rewrite(): + + @squin.kernel + def main(): + n_qubits = 4 + q = squin.qalloc(n_qubits) + + ms = squin.broadcast.measure(q) + + if ms[0]: + squin.z(q[0]) + else: + squin.x(q[0]) + + return + + assert any(isinstance(stmt, scf.IfElse) for stmt in main.code.regions[0].stmts()) + + +def test_nested_if_rewrite(): + + @squin.kernel + def main(): + n_qubits = 4 + q = squin.qalloc(n_qubits) + + ms = squin.broadcast.measure(q) + + if ms[0]: + squin.z(q[0]) + if ms[0]: + squin.x(q[1]) + + return + + assert any(isinstance(stmt, scf.IfElse) for stmt in main.code.regions[0].stmts()) + + +def test_nested_for(): + + @squin.kernel + def main(): + q = squin.qalloc(2) + for i in range(2): + m = squin.broadcast.measure(q) + for j in range(2): + squin.set_detector([m[j]], coordinates=[0, 0]) + + SquinToStimPass(main.dialects)(main) + + base_stim_prog = load_reference_program("nested_for.stim") + + assert base_stim_prog == codegen(main) + + +def test_measure_desugar(): + + @squin.kernel + def main(): + q = squin.qalloc(10) + + pairs = ilist.IList([0, 1, 2, 3]) + + squin.measure(q[pairs[0]]) + for i in range(1): + squin.measure(q[0]) + squin.measure(q[i]) + squin.measure(q[pairs[0]]) + squin.measure(q[pairs[i]]) + + SquinToStimPass(main.dialects)(main) + + base_stim_prog = load_reference_program("measure_desugar.stim") + + assert base_stim_prog == codegen(main) + + +def test_pick_if_else(): + + @squin.kernel + def main(): + q = squin.qalloc(10) + if False: + squin.h(q[0]) + + if True: + squin.h(q[0]) + + SquinToStimPass(main.dialects, no_raise=True)(main) + + assert not any(type(stmt) is scf.IfElse for stmt in main.code.regions[0].stmts()) + + +def test_set_detector_with_alias(): + + @squin.kernel + def main(): + q = squin.qalloc(2) + results = squin.broadcast.measure(q) + results_2 = results + squin.set_detector( + measurements=[results_2[0], results_2[1]], coordinates=[0, 0] + ) + + SquinToStimPass(main.dialects)(main) + + base_stim_prog = load_reference_program("set_detector_with_alias.stim") + + assert base_stim_prog == codegen(main) + + +def test_broadcast_alias(): + + @squin.kernel + def main(): + q = squin.qalloc(2) + q_2 = q + squin.broadcast.x(q_2) + + SquinToStimPass(main.dialects)(main) + + base_stim_prog = load_reference_program("broadcast_with_alias.stim") + + assert base_stim_prog == codegen(main) + + +def test_rep_code(): + + @squin.kernel + def entangle(cx_pairs): + for i in range(len(cx_pairs)): + controls = cx_pairs[i][::2] + targets = cx_pairs[i][1::2] + squin.broadcast.cx(controls, targets) + + @squin.kernel + def rep_code(): + + q = squin.qalloc(5) + data = q[::2] + ancilla = q[1::2] + + # reset everything initially + squin.broadcast.reset(q) + + ## Initial round, entangle data qubits with ancillas. + ## This entanglement will happen again so it's best we + ## save the qubit pairs for reuse. + cx_pair_1 = [data[0], ancilla[0], data[1], ancilla[1]] + cx_pair_2 = [data[1], ancilla[0], data[2], ancilla[1]] + cx_pairs = [cx_pair_1, cx_pair_2] + + entangle(cx_pairs) + + # let's measure the ancillas and set detectors + init_ancilla_meas_res = squin.broadcast.measure(ancilla) + for i in range(len(init_ancilla_meas_res)): + squin.set_detector( + measurements=[init_ancilla_meas_res[i]], coordinates=[0, 0] + ) + + # let's do a standard round now! + entangle(cx_pairs) + round_ancilla_meas_res = squin.broadcast.measure(ancilla) + for i in range(len(init_ancilla_meas_res)): + squin.set_detector( + measurements=[init_ancilla_meas_res[i], round_ancilla_meas_res[i]], + coordinates=[0, 0], + ) + + # Let's make this one a bit noisy (: + entangle(cx_pairs) + + controls = cx_pairs[0][::2] + targets = cx_pairs[0][1::2] + squin.broadcast.depolarize2(p=0.01, controls=controls, targets=targets) + squin.broadcast.qubit_loss(0.001, q) + + new_round_ancilla_meas_res = squin.broadcast.measure(ancilla) + for i in range(len(new_round_ancilla_meas_res)): + squin.set_detector( + measurements=[round_ancilla_meas_res[i], new_round_ancilla_meas_res[i]], + coordinates=[0, 0], + ) + + # finally we want to measure out the data qubits and set final detectors + # The idea is to assert parity of your data qubits with the final round of measurement results + data_meas_res = squin.broadcast.measure(data) + squin.set_detector( + measurements=[ + data_meas_res[0], + data_meas_res[1], + new_round_ancilla_meas_res[0], + ], + coordinates=[0, 0], + ) + squin.set_detector( + measurements=[ + data_meas_res[1], + data_meas_res[2], + new_round_ancilla_meas_res[1], + ], + coordinates=[0, 0], + ) + + # Now we want to dictate a measurement as the observable + squin.set_observable(measurements=[data_meas_res[-1]]) + + SquinToStimPass(rep_code.dialects)(rep_code) + + base_stim_prog = load_reference_program("rep_code.stim") + + assert base_stim_prog == codegen(rep_code) + + +def test_detector_coords_as_args(): + + @squin.kernel + def func(m, x: list): + squin.set_detector(m, coordinates=x) + + @squin.kernel + def main(): + q = squin.qalloc(2) + m = squin.broadcast.measure(q) + + x = 5.0 + y = [3, 4] + z = [1, 2, x] + + squin.set_detector(m, coordinates=[0, 1]) + squin.set_detector(m, coordinates=y) # [3, 4] + squin.set_detector(m, coordinates=[0, x]) # [0, 5.0] + squin.set_detector(m, coordinates=[x, y[0]]) # [5.0, 3] + squin.set_detector(m, coordinates=z) # [1, 2, 5.0] + squin.set_detector(m, coordinates=z[:2]) # [1, 2] + + func(m, z) # [1, 2, 5.0] + + SquinToStimPass(main.dialects)(main) + + base_stim_prog = load_reference_program("detector_coords_as_args.stim") + + assert base_stim_prog == codegen(main) diff --git a/test/stim/passes/test_code_basic_operations.py b/test/stim/passes/test_code_basic_operations.py new file mode 100644 index 00000000..d16db938 --- /dev/null +++ b/test/stim/passes/test_code_basic_operations.py @@ -0,0 +1,100 @@ +""" +The tests here are part of a "base structure" +that comes up in standard QEC operations, chiefly +- Extracting qubits to be ancilla and data qubits +- Applying entangling rounds/operations on the +physical qubits +- Extracting measurements and setting detectors + +They previously helped debug some problems with the +PhysicalAndSquinToStim pass and are included here +""" + +import os +from typing import Any + +from kirin import ir +from kirin.dialects import ilist + +from bloqade import squin +from bloqade.types import Qubit +from bloqade.stim.emit import EmitStimMain +from bloqade.stim.passes import SquinToStimPass + + +def load_stim_reference(filename): + path = os.path.join( + os.path.dirname(__file__), "stim_reference_programs", "annotate", filename + ) + with open(path, "r") as f: + return f.read() + + +def codegen(mt: ir.Method): + # method should not have any arguments! + emit = EmitStimMain() + emit.initialize() + emit.run(mt=mt, args=()) + return emit.get_output() + + +def test_no_kernel_base_op(): + + qids = [0, 2, 1] + + @squin.kernel + def test(): + total_q = squin.qalloc(4) + + def get_qubit(idx: int) -> Qubit: + return total_q[idx] + + # create a subset of qubits + ## use ilist.map to work around + sub_q = ilist.map(get_qubit, qids) + + squin.broadcast.x(sub_q) + m = squin.broadcast.measure(sub_q) + + for i in range(len(m)): + squin.annotate.set_detector(measurements=[m[i]], coordinates=(0, 0)) + + SquinToStimPass(dialects=test.dialects)(test) + + base_stim_prog = load_stim_reference("no_kernel_base_op.stim") + + assert base_stim_prog == codegen(test) + + +def test_kernel_base_op(): + + a_idxs = [1, 3, 2] + + @squin.kernel + def get_a_qubits(q: ilist.IList[Qubit, Any]): + def get_qubit(idx: int) -> Qubit: + return q[idx] + + return ilist.map(get_qubit, a_idxs) + + @squin.kernel + def measure_out(q: ilist.IList[Qubit, Any]): + aq = get_a_qubits(q) + + squin.broadcast.h(aq) + m_a = squin.broadcast.measure(aq) + return m_a + + @squin.kernel + def main(): + qubits = squin.qalloc(6) + mr = measure_out(qubits) + + for i in range(len(mr)): + squin.set_detector(measurements=[mr[i]], coordinates=(0, 0)) + + SquinToStimPass(main.dialects)(main) + + base_stim_prog = load_stim_reference("kernel_base_op.stim") + + assert base_stim_prog == codegen(main) diff --git a/test/test_annotate.py b/test/test_annotate.py new file mode 100644 index 00000000..74974503 --- /dev/null +++ b/test/test_annotate.py @@ -0,0 +1,32 @@ +from kirin import ir + +from bloqade import squin +from bloqade.stim.emit import EmitStimMain +from bloqade.stim.passes import SquinToStimPass + + +def codegen(mt: ir.Method): + # method should not have any arguments! + emit = EmitStimMain() + emit.initialize() + emit.run(mt=mt, args=()) + return emit.get_output() + + +def test_annotate(): + + @squin.kernel + def test(): + qs = squin.qalloc(4) + ms = squin.broadcast.measure(qs) + squin.set_detector([ms[0], ms[1], ms[2]], coordinates=(0, 0)) + squin.set_observable([ms[3]]) + + SquinToStimPass(dialects=test.dialects)(test) + codegen_output = codegen(test) + expected_output = ( + "\nMZ(0.00000000) 0 1 2 3\n" + "DETECTOR(0, 0) rec[-4] rec[-3] rec[-2]\n" + "OBSERVABLE_INCLUDE(0) rec[-1]" + ) + assert codegen_output == expected_output