Skip to content

Commit 64e6028

Browse files
rafaelhaweinbe58
andauthored
Get observable idx from measurement id analysis (#712)
Adresses #671 Removes the global observable index from squin. This index is still required to emit stim -- here it is extracted from the MeasurementIDAnalysis pass. --------- Co-authored-by: Phillip Weinberg <pweinberg@quera.com>
1 parent 9885999 commit 64e6028

File tree

4 files changed

+19
-13
lines changed

4 files changed

+19
-13
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ authors = [
1111
]
1212
requires-python = ">=3.10"
1313
dependencies = [
14-
"bloqade-decoders~=0.3.0",
14+
"bloqade-decoders~=0.4.0",
1515
"numpy>=1.22.0",
1616
"scipy>=1.13.1",
1717
"kirin-toolchain~=0.22.2",

src/bloqade/stim/rewrite/set_observable_to_stim.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from dataclasses import dataclass
22

33
from kirin import ir
4+
from kirin.dialects import py
45
from kirin.rewrite.abc import RewriteRule, RewriteResult
56

67
from bloqade.analysis.measure_id import MeasureIDFrame
78
from bloqade.stim.dialects.auxiliary import ObservableInclude
8-
from bloqade.analysis.measure_id.lattice import MeasureIdTuple
9+
from bloqade.analysis.measure_id.lattice import ObservableId, MeasureIdTuple
910
from bloqade.decoders.dialects.annotate.stmts import SetObservable
1011

1112
from ..rewrite.get_record_util import insert_get_records
@@ -29,13 +30,18 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
2930
def rewrite_SetObservable(self, node: SetObservable) -> RewriteResult:
3031
measure_ids = self.measure_id_frame.entries[node.measurements]
3132
assert isinstance(measure_ids, MeasureIdTuple)
33+
observable_id = self.measure_id_frame.entries.get(node.result)
34+
if not isinstance(observable_id, ObservableId):
35+
return RewriteResult()
3236

3337
get_record_list = insert_get_records(
3438
node, measure_ids, self.measure_id_frame.num_measures_at_stmt[node]
3539
)
3640

41+
idx_stmt = py.Constant(observable_id.idx)
42+
idx_stmt.insert_before(node)
3743
observable_include_stmt = ObservableInclude(
38-
idx=node.idx, targets=tuple(get_record_list)
44+
idx=idx_stmt.result, targets=tuple(get_record_list)
3945
)
4046

4147
node.replace_by(observable_include_stmt)

test/stim/passes/test_annotation_to_stim.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def main():
5050
squin.set_detector([ms[0], ms[1]], coordinates=[0.0, 0.0])
5151
squin.set_detector([ms[1], ms[2]], coordinates=[1.0, 0.0])
5252

53-
squin.set_observable(measurements=[ms[2]], idx=0)
53+
squin.set_observable(measurements=[ms[2]])
5454

5555
return
5656

@@ -81,7 +81,7 @@ def main():
8181

8282
ms1 = squin.broadcast.measure(q)
8383
squin.set_detector([ms1[0], ms1[1]], coordinates=[0.0, 0.0])
84-
squin.set_observable(measurements=[ms1[2]], idx=0)
84+
squin.set_observable(measurements=[ms1[2]])
8585

8686
return
8787

@@ -342,7 +342,7 @@ def rep_code():
342342
)
343343

344344
# Now we want to dictate a measurement as the observable
345-
squin.set_observable(measurements=[data_meas_res[-1]], idx=0)
345+
squin.set_observable(measurements=[data_meas_res[-1]])
346346

347347
SquinToStimPass(rep_code.dialects)(rep_code)
348348

test/test_annotate.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test():
2323
qs = squin.qalloc(4)
2424
ms = squin.broadcast.measure(qs)
2525
squin.set_detector([ms[0], ms[1], ms[2]], coordinates=(0, 0))
26-
squin.set_observable([ms[3]], 0)
26+
squin.set_observable([ms[3]])
2727

2828
SquinToStimPass(dialects=test.dialects)(test)
2929
codegen_output = codegen(test)
@@ -35,22 +35,22 @@ def test():
3535
assert codegen_output == expected_output
3636

3737

38-
def test_set_observable_with_multiple_idx():
38+
def test_set_observable_with_multiple_observables():
3939

4040
@squin.kernel
4141
def test():
4242
qs = squin.qalloc(4)
4343
ms = squin.broadcast.measure(qs)
44-
squin.set_observable([ms[0], ms[1]], 0)
45-
squin.set_observable([ms[2]], 42)
46-
squin.set_observable([ms[3]], 0)
44+
squin.set_observable([ms[0], ms[1]])
45+
squin.set_observable([ms[2]])
46+
squin.set_observable([ms[3]])
4747

4848
SquinToStimPass(dialects=test.dialects)(test)
4949
codegen_output = codegen(test)
5050
expected_output = (
5151
"MZ(0.00000000) 0 1 2 3\n"
5252
"OBSERVABLE_INCLUDE(0) rec[-4] rec[-3]\n"
53-
"OBSERVABLE_INCLUDE(42) rec[-2]\n"
54-
"OBSERVABLE_INCLUDE(0) rec[-1]"
53+
"OBSERVABLE_INCLUDE(1) rec[-2]\n"
54+
"OBSERVABLE_INCLUDE(2) rec[-1]"
5555
)
5656
assert codegen_output == expected_output

0 commit comments

Comments
 (0)