Skip to content

Commit 32113e5

Browse files
committed
get initial implementation of new intermediate dialect approach
1 parent fd62258 commit 32113e5

File tree

16 files changed

+689
-154
lines changed

16 files changed

+689
-154
lines changed

src/bloqade/analysis/measure_id/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from . import impls as impls
2+
from .lattice import RecId as RecId
23
from .analysis import (
34
MeasureIDFrame as MeasureIDFrame,
45
MeasurementIDAnalysis as MeasurementIDAnalysis,

src/bloqade/analysis/measure_id/analysis.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import field, dataclass
22

33
from kirin import ir
4+
from kirin.interp import StatementResult
45
from kirin.analysis import ForwardExtra
56
from typing_extensions import Self
67
from kirin.analysis.forward import ForwardFrame
@@ -17,8 +18,6 @@ class MeasurementIDAnalysis(ForwardExtra[MeasureIDFrame, MeasureId]):
1718

1819
keys = ["measure_id"]
1920
lattice = MeasureId
20-
# for every kind of measurement encountered, increment this
21-
# then use this to generate the negative values for target rec indices
2221
measure_count = 0
2322
detector_count = 0
2423
observable_count = 0
@@ -34,12 +33,18 @@ def initialize_frame(
3433
) -> MeasureIDFrame:
3534
return MeasureIDFrame(node, has_parent_access=has_parent_access)
3635

37-
# Still default to bottom,
38-
# but let constants return the softer "NoMeasureId" type from impl
3936
def eval_fallback(
4037
self, frame: ForwardFrame[MeasureId], node: ir.Statement
4138
) -> tuple[MeasureId, ...]:
4239
return tuple(NotMeasureId() for _ in node.results)
4340

41+
def frame_eval(
42+
self, frame: MeasureIDFrame, node: ir.Statement
43+
) -> StatementResult[MeasureId]:
44+
method = self.lookup_registry(frame, node)
45+
if method is not None:
46+
return method(self, frame, node)
47+
return self.eval_fallback(frame, node)
48+
4449
def method_self(self, method: ir.Method) -> MeasureId:
4550
return self.lattice.bottom()

src/bloqade/analysis/measure_id/impls.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,15 @@
44

55
from bloqade import qubit
66
from bloqade.decoders.dialects import annotate
7+
from bloqade.record_idx_helper import (
8+
GetRecIdxFromPredicate,
9+
GetRecIdxFromMeasurement,
10+
dialect as record_idx_helper_dialect,
11+
)
712
from bloqade.gemini.logical.dialects import operations
813

914
from .lattice import (
15+
RecId,
1016
MeasureId,
1117
Predicate,
1218
DetectorId,
@@ -316,3 +322,38 @@ def if_else(
316322
return else_results
317323
case _:
318324
return interp_.join_results(then_results, else_results)
325+
326+
327+
@record_idx_helper_dialect.register(key="measure_id")
328+
class RecordIdxHelperAnalysis(interp.MethodTable):
329+
330+
@interp.impl(GetRecIdxFromMeasurement)
331+
def get_rec_idx_from_measurement(
332+
self,
333+
interp_: MeasurementIDAnalysis,
334+
frame: MeasureIDFrame,
335+
stmt: GetRecIdxFromMeasurement,
336+
):
337+
measurement_id = frame.get(stmt.measurement)
338+
if not isinstance(measurement_id, (RawMeasureId, MeasureIdBool)):
339+
return (InvalidMeasureId(),)
340+
computed_idx = (measurement_id.idx - 1) - interp_.measure_count
341+
predicate = (
342+
measurement_id.predicate
343+
if isinstance(measurement_id, MeasureIdBool)
344+
else None
345+
)
346+
return (RecId(idx=computed_idx, predicate=predicate),)
347+
348+
@interp.impl(GetRecIdxFromPredicate)
349+
def get_rec_idx_from_predicate(
350+
self,
351+
interp_: MeasurementIDAnalysis,
352+
frame: MeasureIDFrame,
353+
stmt: GetRecIdxFromPredicate,
354+
):
355+
measurement_id = frame.get(stmt.predicate_result)
356+
if not isinstance(measurement_id, MeasureIdBool):
357+
return (InvalidMeasureId(),)
358+
computed_idx = (measurement_id.idx - 1) - interp_.measure_count
359+
return (RecId(idx=computed_idx, predicate=measurement_id.predicate),)

src/bloqade/analysis/measure_id/lattice.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,13 @@ class MeasureIdBool(ConcreteMeasureId):
9191
predicate: Predicate
9292

9393

94+
@final
95+
@dataclass
96+
class RecId(ConcreteMeasureId):
97+
idx: int
98+
predicate: Predicate | None
99+
100+
94101
@final
95102
@dataclass
96103
class DetectorId(MeasureId):
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .stmts import (
2+
GetRecIdxFromPredicate as GetRecIdxFromPredicate,
3+
GetRecIdxFromMeasurement as GetRecIdxFromMeasurement,
4+
)
5+
from ._dialect import dialect as dialect
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from kirin import ir
2+
3+
dialect = ir.Dialect("record_idx_helper")
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from kirin import ir, types
2+
from kirin.decl import info, statement
3+
4+
from bloqade.types import MeasurementResultType
5+
6+
from ._dialect import dialect
7+
8+
9+
@statement(dialect=dialect)
10+
class GetRecIdxFromMeasurement(ir.Statement):
11+
name = "get_rec_idx_from_measurement"
12+
traits = frozenset({ir.Pure()})
13+
measurement: ir.SSAValue = info.argument(type=MeasurementResultType)
14+
result: ir.ResultValue = info.result(type=types.Int)
15+
16+
17+
@statement(dialect=dialect)
18+
class GetRecIdxFromPredicate(ir.Statement):
19+
name = "get_rec_idx_from_predicate"
20+
traits = frozenset({ir.Pure()})
21+
predicate_result: ir.SSAValue = info.argument(type=types.Bool)
22+
result: ir.ResultValue = info.result(type=types.Int)

src/bloqade/stim/passes/squin_to_stim.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,17 @@
1010
from kirin.ir.method import Method
1111
from kirin.passes.abc import Pass
1212
from kirin.rewrite.abc import RewriteResult
13+
from kirin.passes.hint_const import HintConst
1314

1415
from bloqade.stim.rewrite import (
16+
IfToStimPartial,
1517
PyConstantToStim,
18+
ResolveGetRecIdx,
1619
SquinNoiseToStim,
1720
SquinQubitToStim,
21+
SetDetectorPartial,
1822
SquinMeasureToStim,
23+
SetObservablePartial,
1924
)
2025
from bloqade.squin.rewrite import (
2126
SquinU3ToClifford,
@@ -24,82 +29,79 @@
2429
)
2530
from bloqade.rewrite.passes import CanonicalizeIList
2631
from bloqade.analysis.address import AddressAnalysis
32+
from bloqade.record_idx_helper import dialect as record_idx_helper_dialect
2733
from bloqade.analysis.measure_id import MeasurementIDAnalysis
2834
from bloqade.stim.passes.flatten import Flatten
2935

30-
from ..rewrite import IfToStim, SetDetectorToStim, SetObservableToStim
31-
3236

3337
@dataclass
3438
class SquinToStimPass(Pass):
3539

3640
def unsafe_run(self, mt: Method) -> RewriteResult:
3741

38-
# inline aggressively:
3942
rewrite_result = Flatten(dialects=mt.dialects, no_raise=self.no_raise).fixpoint(
4043
mt
4144
)
4245

43-
# after this the program should be in a state where it is analyzable
44-
# -------------------------------------------------------------------
45-
46-
mia = MeasurementIDAnalysis(dialects=mt.dialects)
47-
meas_analysis_frame, _ = mia.run(mt)
48-
4946
aa = AddressAnalysis(dialects=mt.dialects)
5047
address_analysis_frame, _ = aa.run(mt)
5148

52-
# wrap the address analysis result
5349
rewrite_result = (
5450
Walk(WrapAddressAnalysis(address_analysis=address_analysis_frame.entries))
5551
.rewrite(mt.code)
5652
.join(rewrite_result)
5753
)
5854

59-
# 2. rewrite
60-
## Invoke DCE afterwards to eliminate any GetItems
61-
## that are no longer being used. This allows for
62-
## SquinMeasureToStim to safely eliminate
63-
## unused measure statements.
55+
# --- partial rewrite (before analysis) ---
6456
rewrite_result = (
65-
Chain(
66-
Walk(IfToStim(measure_frame=meas_analysis_frame)),
67-
Walk(SetDetectorToStim(measure_id_frame=meas_analysis_frame)),
68-
Walk(SetObservableToStim(measure_id_frame=meas_analysis_frame)),
69-
Fixpoint(Walk(DeadCodeElimination())),
57+
Walk(
58+
Chain(
59+
SetDetectorPartial(),
60+
SetObservablePartial(),
61+
IfToStimPartial(),
62+
)
7063
)
7164
.rewrite(mt.code)
7265
.join(rewrite_result)
7366
)
7467

75-
# Rewrite the noise statements first.
7668
rewrite_result = Walk(SquinNoiseToStim()).rewrite(mt.code).join(rewrite_result)
77-
78-
# Wrap Rewrite + SquinToStim can happen w/ standard walk
7969
rewrite_result = Walk(SquinU3ToClifford()).rewrite(mt.code).join(rewrite_result)
70+
rewrite_result = Walk(SquinQubitToStim()).rewrite(mt.code).join(rewrite_result)
8071

72+
# --- analysis (produces RecId for GetRecIdxFromMeasurement / GetRecIdxFromPredicate) ---
73+
analysis_dialects = mt.dialects.add(record_idx_helper_dialect)
8174
rewrite_result = (
82-
Walk(
83-
Chain(
84-
SquinQubitToStim(),
85-
SquinMeasureToStim(),
86-
)
75+
HintConst(analysis_dialects, no_raise=self.no_raise)
76+
.unsafe_run(mt)
77+
.join(rewrite_result)
78+
)
79+
mia = MeasurementIDAnalysis(dialects=analysis_dialects)
80+
meas_analysis_frame, _ = mia.run(mt)
81+
82+
# --- post-analysis: resolve helper stmts into direct integer constants ---
83+
rewrite_result = (
84+
Chain(
85+
Walk(ResolveGetRecIdx(measure_id_frame=meas_analysis_frame)),
86+
Fixpoint(Walk(DeadCodeElimination())),
8787
)
8888
.rewrite(mt.code)
8989
.join(rewrite_result)
9090
)
9191

92+
# --- rewrite measures (must stay until after analysis) ---
93+
rewrite_result = (
94+
Walk(SquinMeasureToStim()).rewrite(mt.code).join(rewrite_result)
95+
)
96+
9297
rewrite_result = (
9398
CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise)
9499
.unsafe_run(mt)
95100
.join(rewrite_result)
96101
)
97102

98-
# Convert all PyConsts to Stim Constants
99103
rewrite_result = Walk(PyConstantToStim()).rewrite(mt.code).join(rewrite_result)
100104

101-
# clear up leftover stmts
102-
# - remove any squin.qalloc that's left around
103105
rewrite_result = (
104106
Fixpoint(
105107
Walk(

src/bloqade/stim/rewrite/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from .squin_noise import SquinNoiseToStim as SquinNoiseToStim
33
from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
44
from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
5+
from .ifs_to_stim_partial import IfToStimPartial as IfToStimPartial
56
from .py_constant_to_stim import PyConstantToStim as PyConstantToStim
6-
from .set_detector_to_stim import SetDetectorToStim as SetDetectorToStim
7-
from .set_observable_to_stim import SetObservableToStim as SetObservableToStim
7+
from .resolve_get_rec_idx import ResolveGetRecIdx as ResolveGetRecIdx
8+
from .set_detector_partial import SetDetectorPartial as SetDetectorPartial
9+
from .set_observable_partial import SetObservablePartial as SetObservablePartial
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from dataclasses import dataclass
2+
3+
from kirin import ir
4+
from kirin.dialects import scf
5+
from kirin.rewrite.abc import RewriteRule, RewriteResult
6+
7+
from bloqade.squin import gate
8+
from bloqade.squin.rewrite import AddressAttribute
9+
from bloqade.record_idx_helper import GetRecIdxFromPredicate
10+
from bloqade.stim.rewrite.util import insert_qubit_idx_from_address
11+
from bloqade.stim.dialects.gate import CX as stim_CX, CY as stim_CY, CZ as stim_CZ
12+
from bloqade.stim.dialects.auxiliary import GetRecord
13+
from bloqade.stim.rewrite.ifs_to_stim import IfElseSimplification
14+
15+
PAULI_TO_CONTROLLED = {
16+
gate.stmts.X: stim_CX,
17+
gate.stmts.Y: stim_CY,
18+
gate.stmts.Z: stim_CZ,
19+
}
20+
21+
22+
@dataclass
23+
class IfToStimPartial(IfElseSimplification, RewriteRule):
24+
"""Rewrite measurement-conditioned IfElse using GetRecIdxFromPredicate.
25+
26+
Accepts the Bool condition directly (result of IsOne/IsZero) and creates
27+
GetRecIdxFromPredicate -> GetRecord -> CX/CY/CZ. If the body contains
28+
multiple Pauli gates, splits into multiple controlled gates sharing the
29+
same GetRecord result.
30+
"""
31+
32+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
33+
match node:
34+
case scf.IfElse():
35+
return self.rewrite_IfElse(node)
36+
case _:
37+
return RewriteResult()
38+
39+
def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:
40+
if not self.is_rewriteable(stmt):
41+
return RewriteResult()
42+
43+
*body_stmts, _ = stmt.then_body.stmts()
44+
if not body_stmts:
45+
return RewriteResult()
46+
47+
idx_from_predicate_calc = GetRecIdxFromPredicate(predicate_result=stmt.cond)
48+
idx_from_predicate_calc.insert_before(stmt)
49+
50+
get_record_stmt = GetRecord(id=idx_from_predicate_calc.result)
51+
get_record_stmt.insert_before(stmt)
52+
53+
for body_stmt in body_stmts:
54+
address_attr = body_stmt.qubits.hints.get("address")
55+
if address_attr is None:
56+
return RewriteResult()
57+
assert isinstance(address_attr, AddressAttribute)
58+
59+
qubit_idx_ssas = insert_qubit_idx_from_address(
60+
address=address_attr, stmt_to_insert_before=stmt
61+
)
62+
if qubit_idx_ssas is None:
63+
return RewriteResult()
64+
65+
stim_gate_cls = PAULI_TO_CONTROLLED[type(body_stmt)]
66+
stim_stmt = stim_gate_cls(
67+
targets=tuple(qubit_idx_ssas),
68+
controls=(get_record_stmt.result,) * len(qubit_idx_ssas),
69+
)
70+
stim_stmt.insert_before(stmt)
71+
72+
stmt.delete()
73+
74+
return RewriteResult(has_done_something=True)

0 commit comments

Comments
 (0)