diff --git a/.gitignore b/.gitignore
index 1ba3a8cc..d06977b1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,8 @@
+# for stdlib code
+demo/msd_utils/noise.py
+skills
+.agents
+
# misc
.DS_Store
.tmp*
diff --git a/Ion.toml b/Ion.toml
index 47a44b70..8924dfc9 100644
--- a/Ion.toml
+++ b/Ion.toml
@@ -1,5 +1,6 @@
[skills]
ion-cli = { type = "local" }
+msd-paper-decoder-alignment = "skills/msd-paper-decoder-alignment"
receiving-code-review = { source = "obra/superpowers/skills/receiving-code-review" }
requesting-code-review = { source = "obra/superpowers/skills/requesting-code-review" }
using-git-worktrees = { source = "obra/superpowers/skills/using-git-worktrees" }
diff --git a/demo/__init__.py b/demo/__init__.py
new file mode 100644
index 00000000..7de4f741
--- /dev/null
+++ b/demo/__init__.py
@@ -0,0 +1 @@
+"""Demo support package for notebook-oriented helpers."""
diff --git a/demo/benchmark_decoders.ipynb b/demo/benchmark_decoders.ipynb
new file mode 100644
index 00000000..7a0b7973
--- /dev/null
+++ b/demo/benchmark_decoders.ipynb
@@ -0,0 +1,4608 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Benchmarking decoders in `bloqade-lanes`\n",
+ "\n",
+ "This notebook shows a user-level workflow for comparing multiple decoders across multiple logical kernels.\n",
+ "\n",
+ "The core idea is:\n",
+ "\n",
+ "1. compile each kernel into a `GeminiLogicalSimulatorTask`\n",
+ "2. run the task with noise to get `detectors` and `observables`\n",
+ "3. run the same task without noise to get a reference logical output\n",
+ "4. evaluate each decoder on each kernel using a list of metric functions\n",
+ "\n",
+ "The sketch in the design notes used a metric signature that only took `(decoder_ctor, error_model, detectors, observables)`. In practice, metrics such as logical error rate are easier to define if the metric also receives a small `context` dictionary containing the noiseless reference observables and a few run settings.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "ee4a834f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from __future__ import annotations\n",
+ "\n",
+ "from collections import Counter\n",
+ "from time import perf_counter\n",
+ "from typing import Any, Callable\n",
+ "import inspect\n",
+ "import os\n",
+ "import sys\n",
+ "import tracemalloc\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import numpy.typing as npt\n",
+ "import stim\n",
+ "\n",
+ "from bloqade import qubit, squin\n",
+ "from bloqade.gemini import logical\n",
+ "from bloqade.lanes import GeminiLogicalSimulator\n",
+ "from bloqade.decoders import (\n",
+ " BaseDecoder,\n",
+ " BeliefFindDecoder,\n",
+ " BpLsdDecoder,\n",
+ " BpOsdDecoder,\n",
+ " MWPFDecoder,\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "42d2f2c7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "DecoderConstructor = Callable[[stim.DetectorErrorModel], BaseDecoder]\n",
+ "MetricFn = Callable[\n",
+ " [DecoderConstructor, stim.DetectorErrorModel, npt.NDArray[np.bool_], npt.NDArray[np.bool_], dict[str, Any]],\n",
+ " list[float],\n",
+ "]\n",
+ "\n",
+ "\n",
+ "def _decoder_name(decoder_ctor: DecoderConstructor) -> str:\n",
+ " return getattr(decoder_ctor, \"__name__\", decoder_ctor.__class__.__name__)\n",
+ "\n",
+ "\n",
+ "def _mode_row(obs: npt.NDArray[np.bool_]) -> npt.NDArray[np.bool_]:\n",
+ " counts = Counter(map(lambda row: tuple(map(int, row)), obs.tolist()))\n",
+ " row, _ = counts.most_common(1)[0]\n",
+ " return np.asarray(row, dtype=bool)\n",
+ "\n",
+ "\n",
+ "def _decode_and_correct(\n",
+ " decoder_ctor: DecoderConstructor,\n",
+ " error_model: stim.DetectorErrorModel,\n",
+ " detectors: npt.NDArray[np.bool_],\n",
+ " observables: npt.NDArray[np.bool_],\n",
+ ") -> tuple[BaseDecoder, npt.NDArray[np.bool_], npt.NDArray[np.bool_]]:\n",
+ " decoder = decoder_ctor(error_model)\n",
+ " flips = np.asarray(decoder.decode(detectors), dtype=bool)\n",
+ " corrected = np.asarray(observables, dtype=bool) ^ flips\n",
+ " return decoder, flips, corrected\n",
+ "\n",
+ "\n",
+ "def benchmark_decoders(\n",
+ " decoders: list[DecoderConstructor],\n",
+ " kernels: list[Callable[..., Any]],\n",
+ " metrics_fns: dict[str, MetricFn],\n",
+ " *,\n",
+ " shots: int = 10_000,\n",
+ " reference_shots: int = 2_000,\n",
+ " with_noise: bool = True,\n",
+ ") -> dict[str, dict[str, list[list[float]]]]:\n",
+ " \"\"\"\n",
+ " Benchmark multiple decoders on multiple kernels.\n",
+ "\n",
+ " Returns:\n",
+ " results[decoder_name][metric_name][kernel_index] = list[float]\n",
+ " \"\"\"\n",
+ " simulator = GeminiLogicalSimulator()\n",
+ " results: dict[str, dict[str, list[list[float]]]] = {\n",
+ " _decoder_name(decoder): {metric_name: [] for metric_name in metrics_fns}\n",
+ " for decoder in decoders\n",
+ " }\n",
+ "\n",
+ " for kernel in kernels:\n",
+ " task = simulator.task(kernel)\n",
+ " noisy_result = task.run(shots, with_noise=with_noise)\n",
+ " ref_result = task.run(reference_shots, with_noise=False)\n",
+ "\n",
+ " detectors = np.asarray(noisy_result.detectors, dtype=bool)\n",
+ " observables = np.asarray(noisy_result.observables, dtype=bool)\n",
+ " reference_observables = np.asarray(ref_result.observables, dtype=bool)\n",
+ "\n",
+ " context = {\n",
+ " \"kernel_name\": getattr(kernel, \"sym_name\", getattr(kernel, \"__name__\", \"kernel\")),\n",
+ " \"shots\": shots,\n",
+ " \"reference_shots\": reference_shots,\n",
+ " \"reference_observables\": reference_observables,\n",
+ " \"target_observables\": _mode_row(reference_observables),\n",
+ " }\n",
+ "\n",
+ " for decoder_ctor in decoders:\n",
+ " decoder_name = _decoder_name(decoder_ctor)\n",
+ " for metric_name, metric_fn in metrics_fns.items():\n",
+ " metric_values = metric_fn(\n",
+ " decoder_ctor,\n",
+ " task.detector_error_model,\n",
+ " detectors,\n",
+ " observables,\n",
+ " context,\n",
+ " )\n",
+ " results[decoder_name][metric_name].append(metric_values)\n",
+ "\n",
+ " return results\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "7bd08b20",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def logical_error_rate_metric(\n",
+ " decoder_ctor: DecoderConstructor,\n",
+ " error_model: stim.DetectorErrorModel,\n",
+ " detectors: npt.NDArray[np.bool_],\n",
+ " observables: npt.NDArray[np.bool_],\n",
+ " context: dict[str, Any],\n",
+ ") -> list[float]:\n",
+ " _, _, corrected = _decode_and_correct(decoder_ctor, error_model, detectors, observables)\n",
+ " target = np.broadcast_to(context[\"target_observables\"], corrected.shape)\n",
+ " failures = np.any(corrected != target, axis=1)\n",
+ " return [float(np.mean(failures))]\n",
+ "\n",
+ "\n",
+ "def physical_error_rate_metric(\n",
+ " decoder_ctor: DecoderConstructor,\n",
+ " error_model: stim.DetectorErrorModel,\n",
+ " detectors: npt.NDArray[np.bool_],\n",
+ " observables: npt.NDArray[np.bool_],\n",
+ " context: dict[str, Any],\n",
+ ") -> list[float]:\n",
+ " # A convenient proxy for physical error rate from run output alone.\n",
+ " return [float(np.mean(detectors))]\n",
+ "\n",
+ "\n",
+ "def latency_metric(\n",
+ " decoder_ctor: DecoderConstructor,\n",
+ " error_model: stim.DetectorErrorModel,\n",
+ " detectors: npt.NDArray[np.bool_],\n",
+ " observables: npt.NDArray[np.bool_],\n",
+ " context: dict[str, Any],\n",
+ ") -> list[float]:\n",
+ " t0 = perf_counter()\n",
+ " decoder = decoder_ctor(error_model)\n",
+ " t1 = perf_counter()\n",
+ " _ = decoder.decode(detectors)\n",
+ " t2 = perf_counter()\n",
+ " return [1e3 * (t1 - t0), 1e3 * (t2 - t1)]\n",
+ "\n",
+ "\n",
+ "def throughput_metric(\n",
+ " decoder_ctor: DecoderConstructor,\n",
+ " error_model: stim.DetectorErrorModel,\n",
+ " detectors: npt.NDArray[np.bool_],\n",
+ " observables: npt.NDArray[np.bool_],\n",
+ " context: dict[str, Any],\n",
+ ") -> list[float]:\n",
+ " decoder = decoder_ctor(error_model)\n",
+ " t0 = perf_counter()\n",
+ " _ = decoder.decode(detectors)\n",
+ " dt = perf_counter() - t0\n",
+ " shots_per_second = float(len(detectors) / dt) if dt > 0 else float(\"inf\")\n",
+ " return [shots_per_second]\n",
+ "\n",
+ "\n",
+ "def robustness_to_model_mismatch_metric(\n",
+ " decoder_ctor: DecoderConstructor,\n",
+ " error_model: stim.DetectorErrorModel,\n",
+ " detectors: npt.NDArray[np.bool_],\n",
+ " observables: npt.NDArray[np.bool_],\n",
+ " context: dict[str, Any],\n",
+ ") -> list[float]:\n",
+ " mismatch_prob = context.get(\"mismatch_prob\", 0.02)\n",
+ " target = np.broadcast_to(context[\"target_observables\"], observables.shape)\n",
+ "\n",
+ " decoder = decoder_ctor(error_model)\n",
+ " clean_flips = np.asarray(decoder.decode(detectors), dtype=bool)\n",
+ " clean_corrected = observables ^ clean_flips\n",
+ " clean_ler = float(np.mean(np.any(clean_corrected != target, axis=1)))\n",
+ "\n",
+ " rng = np.random.default_rng(0)\n",
+ " detector_noise = rng.random(detectors.shape) < mismatch_prob\n",
+ " mismatched_detectors = detectors ^ detector_noise\n",
+ " mismatch_flips = np.asarray(decoder.decode(mismatched_detectors), dtype=bool)\n",
+ " mismatch_corrected = observables ^ mismatch_flips\n",
+ " mismatch_ler = float(np.mean(np.any(mismatch_corrected != target, axis=1)))\n",
+ " return [clean_ler, mismatch_ler, mismatch_ler - clean_ler]\n",
+ "\n",
+ "\n",
+ "def memory_hardware_cost_metric(\n",
+ " decoder_ctor: DecoderConstructor,\n",
+ " error_model: stim.DetectorErrorModel,\n",
+ " detectors: npt.NDArray[np.bool_],\n",
+ " observables: npt.NDArray[np.bool_],\n",
+ " context: dict[str, Any],\n",
+ ") -> list[float]:\n",
+ " tracemalloc.start()\n",
+ " decoder = decoder_ctor(error_model)\n",
+ " _ = decoder.decode(detectors)\n",
+ " current, peak = tracemalloc.get_traced_memory()\n",
+ " tracemalloc.stop()\n",
+ " shallow_size = float(sys.getsizeof(decoder))\n",
+ " return [shallow_size / (1024**2), peak / (1024**2)]\n",
+ "\n",
+ "\n",
+ "INTERPRETABILITY_PRIORS = {\n",
+ " \"BeliefFindDecoder\": (0.65, 0.55),\n",
+ " \"BpLsdDecoder\": (0.55, 0.60),\n",
+ " \"BpOsdDecoder\": (0.45, 0.72),\n",
+ " \"MWPFDecoder\": (0.35, 0.78),\n",
+ " \"TesseractDecoder\": (0.30, 0.85),\n",
+ "}\n",
+ "\n",
+ "\n",
+ "def interpretability_complexity_metric(\n",
+ " decoder_ctor: DecoderConstructor,\n",
+ " error_model: stim.DetectorErrorModel,\n",
+ " detectors: npt.NDArray[np.bool_],\n",
+ " observables: npt.NDArray[np.bool_],\n",
+ " context: dict[str, Any],\n",
+ ") -> list[float]:\n",
+ " name = _decoder_name(decoder_ctor)\n",
+ " if name in INTERPRETABILITY_PRIORS:\n",
+ " interpretability, complexity = INTERPRETABILITY_PRIORS[name]\n",
+ " else:\n",
+ " n_params = len(inspect.signature(decoder_ctor).parameters)\n",
+ " complexity = min(1.0, 0.4 + 0.05 * n_params)\n",
+ " interpretability = max(0.0, 1.0 - complexity)\n",
+ " return [interpretability, complexity]\n",
+ "\n",
+ "\n",
+ "def decoder_confusion_matrix_metric(\n",
+ " decoder_ctor: DecoderConstructor,\n",
+ " error_model: stim.DetectorErrorModel,\n",
+ " detectors: npt.NDArray[np.bool_],\n",
+ " observables: npt.NDArray[np.bool_],\n",
+ " context: dict[str, Any],\n",
+ ") -> list[float]:\n",
+ " target = np.broadcast_to(context[\"target_observables\"], observables.shape)\n",
+ " _, flips, _ = _decode_and_correct(decoder_ctor, error_model, detectors, observables)\n",
+ " actual_error = np.any(observables != target, axis=1)\n",
+ " predicted_error = np.any(flips, axis=1)\n",
+ " tn = float(np.sum((~actual_error) & (~predicted_error)))\n",
+ " fp = float(np.sum((~actual_error) & predicted_error))\n",
+ " fn = float(np.sum(actual_error & (~predicted_error)))\n",
+ " tp = float(np.sum(actual_error & predicted_error))\n",
+ " return [tn, fp, fn, tp]\n",
+ "\n",
+ "\n",
+ "metrics = {\n",
+ " \"logical_error_rate\": logical_error_rate_metric,\n",
+ " \"physical_error_rate\": physical_error_rate_metric,\n",
+ " \"latency\": latency_metric,\n",
+ " \"throughput\": throughput_metric,\n",
+ " \"robustness_to_model_mismatch\": robustness_to_model_mismatch_metric,\n",
+ " \"memory_hardware_cost\": memory_hardware_cost_metric,\n",
+ " \"interpretability_implementation_complexity\": interpretability_complexity_metric,\n",
+ " \"decoder_confusion_matrix\": decoder_confusion_matrix_metric,\n",
+ "}\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4131335d",
+ "metadata": {},
+ "source": [
+ "## Example kernels\n",
+ "\n",
+ "These are intentionally small kernels so the benchmarking loop is easy to run and modify.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "83cd9293",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@logical.kernel(aggressive_unroll=True, verify=True)\n",
+ "def ghz3_kernel():\n",
+ " reg = qubit.qalloc(3)\n",
+ " squin.h(reg[0])\n",
+ " squin.cx(reg[0], reg[1])\n",
+ " squin.cx(reg[1], reg[2])\n",
+ " logical.default_post_processing(reg)\n",
+ "\n",
+ "\n",
+ "@logical.kernel(aggressive_unroll=True, verify=True)\n",
+ "def bell2_kernel():\n",
+ " reg = qubit.qalloc(2)\n",
+ " squin.h(reg[0])\n",
+ " squin.cx(reg[0], reg[1])\n",
+ " logical.default_post_processing(reg)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "b4279281",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'BpLsdDecoder': {'logical_error_rate': [[0.5025]],\n",
+ " 'physical_error_rate': [[0.06722222222222222]],\n",
+ " 'latency': [[0.435041991295293, 30.871625000145286]],\n",
+ " 'throughput': [[64528.097948343835]],\n",
+ " 'robustness_to_model_mismatch': [[0.5025, 0.5545, 0.052000000000000046]],\n",
+ " 'memory_hardware_cost': [[4.57763671875e-05, 0.5877389907836914]],\n",
+ " 'interpretability_implementation_complexity': [[0.55, 0.6]],\n",
+ " 'decoder_confusion_matrix': [[911.0, 31.0, 890.0, 168.0]]},\n",
+ " 'BpOsdDecoder': {'logical_error_rate': [[0.5025]],\n",
+ " 'physical_error_rate': [[0.06722222222222222]],\n",
+ " 'latency': [[0.4953749885316938, 28.884500003186986]],\n",
+ " 'throughput': [[68273.65787464386]],\n",
+ " 'robustness_to_model_mismatch': [[0.5025, 0.5555, 0.05300000000000005]],\n",
+ " 'memory_hardware_cost': [[4.57763671875e-05, 0.5878000259399414]],\n",
+ " 'interpretability_implementation_complexity': [[0.45, 0.72]],\n",
+ " 'decoder_confusion_matrix': [[911.0, 31.0, 892.0, 166.0]]}}"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "decoders = [BpLsdDecoder, BpOsdDecoder]\n",
+ "kernels = [ghz3_kernel]\n",
+ "\n",
+ "benchmark_results = benchmark_decoders(\n",
+ " decoders=decoders,\n",
+ " kernels=kernels,\n",
+ " metrics_fns=metrics,\n",
+ " shots=2_000,\n",
+ " reference_shots=512,\n",
+ ")\n",
+ "\n",
+ "benchmark_results\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "304a9654",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "logical_error_rate\n",
+ "==================\n",
+    "\nBpLsdDecoder:\n",
+    "  ghz3_kernel: [0.5025]\n",
+    "\nBpOsdDecoder:\n",
+    "  ghz3_kernel: [0.5025]\n",
+    "throughput\n",
+    "==========\n",
+    "\nBpLsdDecoder:\n",
+    "  ghz3_kernel: [64528.097948343835]\n",
+    "\nBpOsdDecoder:\n",
+    "  ghz3_kernel: [68273.65787464386]\n"
+ ]
+ }
+ ],
+ "source": [
+ "def print_metric_table(results: dict[str, dict[str, list[list[float]]]], metric_name: str, kernels: list[Callable[..., Any]]):\n",
+ " kernel_names = [getattr(kernel, \"sym_name\", getattr(kernel, \"__name__\", \"kernel\")) for kernel in kernels]\n",
+ " print(metric_name)\n",
+ " print(\"=\" * len(metric_name))\n",
+ " for decoder_name, decoder_results in results.items():\n",
+    "        print(f\"\\n{decoder_name}:\")\n",
+ " for kernel_name, values in zip(kernel_names, decoder_results[metric_name]):\n",
+ " print(f\" {kernel_name}: {values}\")\n",
+ "\n",
+ "\n",
+ "print_metric_table(benchmark_results, \"logical_error_rate\", kernels)\n",
+ "print_metric_table(benchmark_results, \"throughput\", kernels)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "dfcfc649",
+ "metadata": {},
+ "source": [
+ "## Example: custom metric functions\n",
+ "\n",
+ "Users can add their own metrics as long as they follow the `MetricFn` signature used above.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "ffe34e8c",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'BpLsdDecoder': {'logical_error_rate': [[0.515]],\n",
+ " 'average_flips': [[0.06166666666666667]]},\n",
+ " 'BpOsdDecoder': {'logical_error_rate': [[0.516]],\n",
+ " 'average_flips': [[0.060333333333333336]]}}"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "def average_flips_metric(\n",
+ " decoder_ctor: DecoderConstructor,\n",
+ " error_model: stim.DetectorErrorModel,\n",
+ " detectors: npt.NDArray[np.bool_],\n",
+ " observables: npt.NDArray[np.bool_],\n",
+ " context: dict[str, Any],\n",
+ ") -> list[float]:\n",
+ " _, flips, _ = _decode_and_correct(decoder_ctor, error_model, detectors, observables)\n",
+ " return [float(np.mean(flips))]\n",
+ "\n",
+ "\n",
+ "custom_metrics = {\n",
+ " \"logical_error_rate\": logical_error_rate_metric,\n",
+ " \"average_flips\": average_flips_metric,\n",
+ "}\n",
+ "\n",
+ "custom_results = benchmark_decoders(\n",
+ " decoders=[BpLsdDecoder, BpOsdDecoder],\n",
+ " kernels=[ghz3_kernel],\n",
+ " metrics_fns=custom_metrics,\n",
+ " shots=1_000,\n",
+ " reference_shots=256,\n",
+ ")\n",
+ "\n",
+ "custom_results\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "71f31940",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/var/folders/4t/_g5ztpl96sg8j3_ztt9wgc400000gp/T/ipykernel_98121/3795600238.py:25: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
+ " fig.tight_layout()\n"
+ ]
+ },
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "def visualize_benchmarks(\n",
+ " decoder_metrics_results: dict[str, dict[str, list[list[float]]]],\n",
+ " decoders_to_vis: list[str] | None = None,\n",
+ " metrics_to_vis: list[str] | None = None,\n",
+ " save_folder_path: str | None = None,\n",
+ "):\n",
+ " \"\"\"\n",
+ " Takes in some collection of metrics results from the decoders, a list of\n",
+ " decoder names to visualize, a list of metric names to visualize, and\n",
+ " either displays or saves a small set of benchmark visualizations.\n",
+ " \"\"\"\n",
+ " selected_decoders = decoders_to_vis or list(decoder_metrics_results.keys())\n",
+ " if not selected_decoders:\n",
+ " raise ValueError(\"No decoders selected for visualization\")\n",
+ "\n",
+ " available_metrics = set()\n",
+ " for decoder_name in selected_decoders:\n",
+ " available_metrics.update(decoder_metrics_results[decoder_name].keys())\n",
+ " selected_metrics = set(metrics_to_vis or available_metrics)\n",
+ "\n",
+ " if save_folder_path is not None:\n",
+ " os.makedirs(save_folder_path, exist_ok=True)\n",
+ "\n",
+ " def _finish(fig: plt.Figure, metric_name: str, vis_name: str):\n",
+ " fig.tight_layout()\n",
+ " if save_folder_path is not None:\n",
+ " filename = f\"metrics_{metric_name}_vis_{vis_name}.png\"\n",
+ " fig.savefig(os.path.join(save_folder_path, filename), dpi=180, bbox_inches=\"tight\")\n",
+ " plt.close(fig)\n",
+ " else:\n",
+ " plt.show()\n",
+ "\n",
+ " if {\"logical_error_rate\", \"physical_error_rate\"}.issubset(selected_metrics):\n",
+ " fig, ax = plt.subplots(figsize=(7, 5))\n",
+ " for decoder_name in selected_decoders:\n",
+ " logical_error_rate = [values[0] for values in decoder_metrics_results[decoder_name][\"logical_error_rate\"]]\n",
+ " physical_error_rate = [values[0] for values in decoder_metrics_results[decoder_name][\"physical_error_rate\"]]\n",
+ " ax.plot(physical_error_rate, logical_error_rate, marker=\"o\", linestyle=\"-\", label=decoder_name)\n",
+ " ax.set_xlabel(\"Physical error rate proxy (detector click density)\")\n",
+ " ax.set_ylabel(\"Logical error rate\")\n",
+ " ax.set_title(\"Logical error rate vs physical error rate\")\n",
+ " ax.grid(True, alpha=0.3)\n",
+ " ax.legend()\n",
+ " _finish(fig, \"logical_error_rate\", \"overlay\")\n",
+ "\n",
+ " if {\"latency\", \"logical_error_rate\"}.issubset(selected_metrics):\n",
+ " fig, ax = plt.subplots(figsize=(7, 5))\n",
+ " for decoder_name in selected_decoders:\n",
+ " logical_error_rate = [values[0] for values in decoder_metrics_results[decoder_name][\"logical_error_rate\"]]\n",
+ " latency = [values[1] if len(values) > 1 else values[0] for values in decoder_metrics_results[decoder_name][\"latency\"]]\n",
+ " ax.scatter(latency, logical_error_rate, s=90, label=decoder_name)\n",
+ " for idx, (x, y) in enumerate(zip(latency, logical_error_rate)):\n",
+ " ax.annotate(str(idx), (x, y), textcoords=\"offset points\", xytext=(4, 4), fontsize=8)\n",
+ " ax.set_xlabel(\"Decode latency (ms)\")\n",
+ " ax.set_ylabel(\"Logical error rate\")\n",
+ " ax.set_title(\"Latency vs logical error rate\")\n",
+ " ax.grid(True, alpha=0.3)\n",
+ " ax.legend()\n",
+ " _finish(fig, \"latency\", \"scatter\")\n",
+ "\n",
+ " if \"decoder_confusion_matrix\" in selected_metrics:\n",
+ " n = len(selected_decoders)\n",
+ " fig, axes = plt.subplots(1, n, figsize=(4 * n, 4), squeeze=False)\n",
+ " axes = axes[0]\n",
+ " image = None\n",
+ " for ax, decoder_name in zip(axes, selected_decoders):\n",
+ " mats = np.asarray(decoder_metrics_results[decoder_name][\"decoder_confusion_matrix\"], dtype=float)\n",
+ " agg = np.sum(mats, axis=0)\n",
+ " cm = np.array([[agg[0], agg[1]], [agg[2], agg[3]]], dtype=float)\n",
+ " image = ax.imshow(cm, cmap=\"Blues\")\n",
+ " ax.set_title(decoder_name)\n",
+ " ax.set_xticks([0, 1], labels=[\"Pred: no error\", \"Pred: error\"])\n",
+ " ax.set_yticks([0, 1], labels=[\"True: no error\", \"True: error\"])\n",
+ " for i in range(2):\n",
+ " for j in range(2):\n",
+ " ax.text(j, i, int(cm[i, j]), ha=\"center\", va=\"center\")\n",
+ " if image is not None:\n",
+ " fig.colorbar(image, ax=axes.ravel().tolist(), shrink=0.8)\n",
+ " fig.suptitle(\"Side-by-side confusion matrices\")\n",
+ " _finish(fig, \"decoder_confusion_matrix\", \"debug\")\n",
+ "\n",
+ "\n",
+ "visualize_benchmarks(\n",
+ " benchmark_results,\n",
+ " decoders_to_vis=[\"BpLsdDecoder\", \"BpOsdDecoder\"],\n",
+ " metrics_to_vis=[\"logical_error_rate\", \"physical_error_rate\", \"latency\", \"decoder_confusion_matrix\"],\n",
+ " save_folder_path=None,\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5a020069",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "620505e3",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "kirin-workspace (3.12.13)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/demo/code_concatenation.ipynb b/demo/code_concatenation.ipynb
new file mode 100644
index 00000000..c0309520
--- /dev/null
+++ b/demo/code_concatenation.ipynb
@@ -0,0 +1,17021 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Code Concatenation with Steane-on-Steane\n",
+ "\n",
+ "This notebook constructs a concatenated Steane code in the current `@logical.kernel` stack by using an **outer** `[[7,1,3]]` Steane code over 7 logical qubits. Each of those 7 logical qubits is then lowered by `GeminiLogicalSimulator` into its own inner Steane block, producing a `[[49,1,9]]` concatenated code at the physical-measurement level.\n",
+ "\n",
+ "A subtle but important point: this is **not** recursive lowering of an already-lowered squin kernel back into a new logical kernel. The current stack does not expose that as a first-class user workflow. Instead, we use the supported path: one outer logical kernel plus custom `m2dets`/`m2obs` matrices that interpret the final 49 physical measurements as inner and outer Steane syndromes and observables."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "from scipy.linalg import block_diag\n",
+ "\n",
+ "from bloqade import qubit, squin\n",
+ "from bloqade.gemini import logical\n",
+ "from bloqade.lanes import GeminiLogicalSimulator\n",
+ "from bloqade.decoders import BpOsdDecoder"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(49, 24) (49, 1)\n"
+ ]
+ }
+ ],
+ "source": [
+ "STEANE_H = np.array(\n",
+ " [\n",
+ " [1, 1, 1, 1, 0, 0, 0],\n",
+ " [0, 1, 1, 0, 1, 1, 0],\n",
+ " [0, 0, 1, 1, 1, 0, 1],\n",
+ " ],\n",
+ " dtype=int,\n",
+ ")\n",
+ "\n",
+ "# Logical Z support used by the current Steane default post-processing path.\n",
+ "STEANE_OBS = np.array([[1, 1, 0, 0, 0, 1, 0]], dtype=int)\n",
+ "\n",
+ "\n",
+ "def concatenated_steane_matrices():\n",
+ " \"\"\"Return measurement-to-detector / measurement-to-observable matrices for\n",
+ " Steane-on-Steane concatenation.\n",
+ "\n",
+ " The 49 final physical measurements are grouped into 7 inner Steane blocks.\n",
+ " We expose:\n",
+ " - 21 inner detectors (3 per inner block)\n",
+ " - 3 outer detectors, built from the 7 inner logical bits\n",
+ " - 1 outer observable\n",
+ " \"\"\"\n",
+ " inner_det = np.asarray(block_diag(*[STEANE_H.T] * 7), dtype=int) # (49, 21)\n",
+ " inner_obs = np.asarray(block_diag(*[STEANE_OBS.T] * 7), dtype=int) # (49, 7)\n",
+ "\n",
+ " outer_det = (inner_obs @ STEANE_H.T) % 2 # (49, 3)\n",
+ " outer_obs = (inner_obs @ STEANE_OBS.T) % 2 # (49, 1)\n",
+ "\n",
+ " m2dets = np.concatenate([inner_det, outer_det], axis=1) % 2\n",
+ " m2obs = outer_obs % 2\n",
+ " return m2dets.tolist(), m2obs.tolist()\n",
+ "\n",
+ "\n",
+ "concat_m2dets, concat_m2obs = concatenated_steane_matrices()\n",
+ "print(np.asarray(concat_m2dets).shape, np.asarray(concat_m2obs).shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Outer logical kernel\n",
+ "\n",
+ "This kernel acts on 7 **logical** qubits. The simulator then lowers each logical qubit to an inner Steane block. The circuit below prepares an outer encoded `|0_L>` using a standard Steane encoder pattern. Because the final observable is a logical Z readout, the noiseless target is deterministic."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "f791961f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from kirin.dialects import ilist"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@logical.kernel(aggressive_unroll=True, verify=True)\n",
+ "def concatenated_steane_memory():\n",
+ " q = qubit.qalloc(7)\n",
+ "\n",
+ " # Steane encoder: input |0> on q[6], with three |+> ancillas.\n",
+ " squin.h(q[1])\n",
+ " squin.h(q[2])\n",
+ " squin.h(q[3])\n",
+ "\n",
+ " squin.cx(q[6], q[5])\n",
+ " squin.cx(q[1], q[0])\n",
+ " squin.cx(q[2], q[4])\n",
+ " squin.cx(q[2], q[0])\n",
+ " squin.cx(q[3], q[5])\n",
+ " squin.cx(q[1], q[5])\n",
+ " squin.cx(q[6], q[4])\n",
+ " squin.cx(q[2], q[6])\n",
+ " squin.cx(q[3], q[4])\n",
+ " squin.cx(q[3], q[0])\n",
+ " squin.cx(q[1], q[6])\n",
+ "\n",
+ " # The task layer will append terminal measurements and annotations from\n",
+ " # concat_m2dets / concat_m2obs.\n",
+ " return"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "measurements shape: (4000, 49)\n",
+ "detectors shape: (4000, 24)\n",
+ "observables shape: (4000, 1)\n",
+ "target observable: [1]\n"
+ ]
+ }
+ ],
+ "source": [
+ "sim = GeminiLogicalSimulator()\n",
+ "task = sim.task(\n",
+ " concatenated_steane_memory,\n",
+ " m2dets=concat_m2dets,\n",
+ " m2obs=concat_m2obs,\n",
+ ")\n",
+ "\n",
+ "shots = 4000\n",
+ "result = task.run(shots, with_noise=True)\n",
+ "result_ideal = task.run(256, with_noise=False)\n",
+ "\n",
+ "measurements = np.asarray(result.measurements, dtype=bool)\n",
+ "detectors = np.asarray(result.detectors, dtype=bool)\n",
+ "observables = np.asarray(result.observables, dtype=bool)\n",
+ "ideal_observables = np.asarray(result_ideal.observables, dtype=bool)\n",
+ "target_obs = ideal_observables[0]\n",
+ "\n",
+ "print('measurements shape:', measurements.shape)\n",
+ "print('detectors shape: ', detectors.shape)\n",
+ "print('observables shape: ', observables.shape)\n",
+ "print('target observable: ', target_obs.astype(int).tolist())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8707c63d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "# scratch cell: inspect available task attributes with dir(task)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "93f358b0",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ "
\n",
+ "