Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/bloqade/analysis/measure_id/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from kirin.analysis import const
from kirin.dialects import py, scf, func, ilist

from bloqade import qubit
from bloqade import qubit, annotate

from .lattice import (
AnyMeasureId,
Expand Down Expand Up @@ -46,6 +46,20 @@ def measure_qubit_list(
return (MeasureIdTuple(data=tuple(measure_id_bools)),)


@annotate.dialect.register(key="measure_id")
class Annotate(interp.MethodTable):
@interp.impl(annotate.stmts.SetObservable)
@interp.impl(annotate.stmts.SetDetector)
def consumes_measurement_results(
self,
interp: MeasurementIDAnalysis,
frame: MeasureIDFrame,
stmt: annotate.stmts.SetObservable | annotate.stmts.SetDetector,
):
frame.num_measures_at_stmt[stmt] = interp.measure_count
return (NotMeasureId(),)


@ilist.dialect.register(key="measure_id")
class IList(interp.MethodTable):
@interp.impl(ilist.New)
Expand Down
6 changes: 6 additions & 0 deletions src/bloqade/annotate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from . import stmts as stmts
from ._dialect import dialect as dialect
from ._interface import (
set_detector as set_detector,
set_observable as set_observable,
)
3 changes: 3 additions & 0 deletions src/bloqade/annotate/_dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from kirin import ir

dialect = ir.Dialect("squin.annotate")
21 changes: 21 additions & 0 deletions src/bloqade/annotate/_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Any

from kirin.dialects import ilist
from kirin.lowering import wraps

from bloqade.types import MeasurementResult

from .stmts import SetDetector, SetObservable


@wraps(SetDetector)
def set_detector(
measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult],
coordinates: tuple[float | int, ...],
) -> None: ...


@wraps(SetObservable)
def set_observable(
measurements: ilist.IList[MeasurementResult, Any] | list[MeasurementResult],
) -> None: ...
27 changes: 27 additions & 0 deletions src/bloqade/annotate/stmts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from kirin import ir, types as kirin_types, lowering
from kirin.decl import info, statement
from kirin.dialects import ilist

from bloqade.types import MeasurementResultType

from ._dialect import dialect


@statement
class ConsumesMeasurementResults(ir.Statement):
traits = frozenset({lowering.FromPythonCall()})
inputs: ir.SSAValue = info.argument(
ilist.IListType[MeasurementResultType, kirin_types.Any]
)


@statement(dialect=dialect)
class SetDetector(ConsumesMeasurementResults):
coordinates: ir.SSAValue = info.argument(
type=kirin_types.Tuple[kirin_types.Int | kirin_types.Float]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be an IList? Is the length here known?

Judging from the tests, this could be IListType[types.Float, types.Literal(2)].

Do we need to support Int here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this type is wrong.

  1. tuple[int] means it only have one element....
  2. I think the intention here is the concern of sometimes ppl do: (0, 3.2, 5) but I think we should actually not allow it moving forward (error when validate in the future)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I think kirin's Tuple type actually aliases to a vararg tuple.
  2. I agree, so then this type should be changed to ilist.

Copy link
Contributor

@kaihsin kaihsin Nov 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, so the right way to hint it (if you really want to use tuple) is:
types.Tuple[types.Float, ...] which is equivalent to types.Tuple[types.Vararg(types.Float)]

Also, for 2. we can discuss if we want that UX. I don't have strong opinion on that.

Copy link
Contributor Author

@johnzl-777 johnzl-777 Nov 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kaihsin why do we want to prohibit variable length tuple for the coordinates? My understanding is they just exist to make visualization easier, it doesn't affect any functionality.

Judging from the tests, this could be IListType[types.Float, types.Literal(2)].

@david-pl I think the only reason I cooked up tests with two integer coordinates is because I don't think I've ever seen a 3D coordinate used. That being said, I imagine there must be more complicated codes where this would be desirable

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kaihsin's point with the tuple length was that ´Tuple[Int]´ indicates a tuple with only a single element. So the correct typing for variable length here would be ´Tuple[Int, ...]´.

I see, then disregard the comment on the length.

)


@statement(dialect=dialect)
class SetObservable(ConsumesMeasurementResults):
pass
3 changes: 2 additions & 1 deletion src/bloqade/squin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
noise as noise,
analysis as analysis,
)
from .. import qubit as qubit
from .. import qubit as qubit, annotate as annotate
from ..qubit import (
reset as reset,
qalloc as qalloc,
Expand All @@ -12,6 +12,7 @@
get_measurement_id as get_measurement_id,
)
from .groups import kernel as kernel
from ..annotate import set_detector as set_detector, set_observable as set_observable
from .stdlib.simple import (
h as h,
s as s,
Expand Down
4 changes: 2 additions & 2 deletions src/bloqade/squin/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from kirin.dialects import debug, ilist

from . import gate, noise
from .. import qubit
from .. import qubit, annotate


@ir.dialect_group(structural_no_opt.union([qubit, noise, gate, debug]))
@ir.dialect_group(structural_no_opt.union([qubit, noise, gate, debug, annotate]))
def kernel(self):
fold_pass = passes.Fold(self)
typeinfer_pass = passes.TypeInfer(self)
Expand Down
4 changes: 3 additions & 1 deletion src/bloqade/stim/passes/squin_to_stim.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from bloqade.analysis.measure_id import MeasurementIDAnalysis
from bloqade.stim.passes.flatten import Flatten

from ..rewrite.ifs_to_stim import IfToStim
from ..rewrite import IfToStim, SetDetectorToStim, SetObservableToStim


@dataclass
Expand Down Expand Up @@ -64,6 +64,8 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
rewrite_result = (
Chain(
Walk(IfToStim(measure_frame=meas_analysis_frame)),
Walk(SetDetectorToStim(measure_id_frame=meas_analysis_frame)),
Walk(SetObservableToStim(measure_id_frame=meas_analysis_frame)),
Fixpoint(Walk(DeadCodeElimination())),
)
.rewrite(mt.code)
Expand Down
2 changes: 2 additions & 0 deletions src/bloqade/stim/rewrite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
from .py_constant_to_stim import PyConstantToStim as PyConstantToStim
from .set_detector_to_stim import SetDetectorToStim as SetDetectorToStim
from .set_observable_to_stim import SetObservableToStim as SetObservableToStim
24 changes: 24 additions & 0 deletions src/bloqade/stim/rewrite/get_record_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from kirin import ir
from kirin.dialects import py

from bloqade.stim.dialects import auxiliary
from bloqade.analysis.measure_id.lattice import MeasureIdBool, MeasureIdTuple


def insert_get_records(
node: ir.Statement, measure_id_tuple: MeasureIdTuple, meas_count_at_stmt: int
):
"""
Insert GetRecord statements before the given node
"""
get_record_ssas = []
for measure_id_bool in measure_id_tuple.data:
assert isinstance(measure_id_bool, MeasureIdBool)
target_rec_idx = (measure_id_bool.idx - 1) - meas_count_at_stmt
idx_stmt = py.constant.Constant(target_rec_idx)
idx_stmt.insert_before(node)
get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
get_record_stmt.insert_before(node)
get_record_ssas.append(get_record_stmt.result)

return get_record_ssas
68 changes: 68 additions & 0 deletions src/bloqade/stim/rewrite/set_detector_to_stim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Iterable
from dataclasses import dataclass

from kirin import ir
from kirin.dialects.py import Constant
from kirin.rewrite.abc import RewriteRule, RewriteResult

from bloqade.stim.dialects import auxiliary
from bloqade.annotate.stmts import SetDetector
from bloqade.analysis.measure_id import MeasureIDFrame
from bloqade.stim.dialects.auxiliary import Detector
from bloqade.analysis.measure_id.lattice import MeasureIdTuple

from ..rewrite.get_record_util import insert_get_records


@dataclass
class SetDetectorToStim(RewriteRule):
"""
Rewrite SetDetector to GetRecord and Detector in the stim dialect
"""

measure_id_frame: MeasureIDFrame

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
match node:
case SetDetector():
return self.rewrite_SetDetector(node)
case _:
return RewriteResult()

def rewrite_SetDetector(self, node: SetDetector) -> RewriteResult:

# get coordinates and generate correct consts
coord_ssas = []
if not isinstance(node.coordinates.owner, Constant):
return RewriteResult()

coord_values = node.coordinates.owner.value.unwrap()

if not isinstance(coord_values, Iterable):
return RewriteResult()

if any(not isinstance(value, (int, float)) for value in coord_values):
return RewriteResult()

for coord_value in coord_values:
if isinstance(coord_value, float):
coord_stmt = auxiliary.ConstFloat(value=coord_value)
else: # int
coord_stmt = auxiliary.ConstInt(value=coord_value)
coord_ssas.append(coord_stmt.result)
coord_stmt.insert_before(node)

measure_ids = self.measure_id_frame.entries[node.inputs]
assert isinstance(measure_ids, MeasureIdTuple)

get_record_list = insert_get_records(
node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node]
)

detector_stmt = Detector(
coord=tuple(coord_ssas), targets=tuple(get_record_list)
)

node.replace_by(detector_stmt)

return RewriteResult(has_done_something=True)
52 changes: 52 additions & 0 deletions src/bloqade/stim/rewrite/set_observable_to_stim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from dataclasses import dataclass

from kirin import ir
from kirin.rewrite.abc import RewriteRule, RewriteResult

from bloqade.stim.dialects import auxiliary
from bloqade.annotate.stmts import SetObservable
from bloqade.analysis.measure_id import MeasureIDFrame
from bloqade.stim.dialects.auxiliary import ObservableInclude
from bloqade.analysis.measure_id.lattice import MeasureIdTuple

from ..rewrite.get_record_util import insert_get_records


@dataclass
class SetObservableToStim(RewriteRule):
"""
Rewrite SetObservable to GetRecord and ObservableInclude in the stim dialect
"""

measure_id_frame: MeasureIDFrame

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
match node:
case SetObservable():
return self.rewrite_SetObservable(node)
case _:
return RewriteResult()

def rewrite_SetObservable(self, node: SetObservable) -> RewriteResult:

# set idx to 0 for now, but this
# should be something that a user can set on their own.
# SetObservable needs to accept an int.

idx_stmt = auxiliary.ConstInt(value=0)
idx_stmt.insert_before(node)

measure_ids = self.measure_id_frame.entries[node.inputs]
assert isinstance(measure_ids, MeasureIdTuple)

get_record_list = insert_get_records(
node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node]
)

observable_include_stmt = ObservableInclude(
idx=idx_stmt.result, targets=tuple(get_record_list)
)

node.replace_by(observable_include_stmt)

return RewriteResult(has_done_something=True)
32 changes: 32 additions & 0 deletions test/test_annotate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from kirin import ir

from bloqade import squin
from bloqade.stim.emit import EmitStimMain
from bloqade.stim.passes import SquinToStimPass


def codegen(mt: ir.Method):
# method should not have any arguments!
emit = EmitStimMain()
emit.initialize()
emit.run(mt=mt, args=())
return emit.get_output()


def test_annotate():

@squin.kernel
def test():
qs = squin.qalloc(4)
ms = squin.broadcast.measure(qs)
squin.set_detector([ms[0], ms[1], ms[2]], coordinates=(0, 0))
squin.set_observable([ms[3]])

SquinToStimPass(dialects=test.dialects)(test)
codegen_output = codegen(test)
expected_output = (
"\nMZ(0.00000000) 0 1 2 3\n"
"DETECTOR(0, 0) rec[-4] rec[-3] rec[-2]\n"
"OBSERVABLE_INCLUDE(0) rec[-1]"
)
assert codegen_output == expected_output