Skip to content

Commit 8f4a889

Browse files
johnzl-777kaihsin
andcommitted
add ability to track number of measurements at statements of interest (#413)
Every time an scf.IfElse is encountered the total number of measurements that have occurred will be saved and accessible in later rewrites --------- Co-authored-by: Kai-Hsin Wu <[email protected]>
1 parent e0b99db commit 8f4a889

File tree

3 files changed

+127
-7
lines changed

3 files changed

+127
-7
lines changed

src/bloqade/analysis/measure_id/analysis.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,31 @@
11
from typing import TypeVar
2+
from dataclasses import field, dataclass
23

34
from kirin import ir, interp
4-
from kirin.analysis import Forward, const
5+
from kirin.analysis import ForwardExtra, const
56
from kirin.analysis.forward import ForwardFrame
67

78
from .lattice import MeasureId, NotMeasureId
89

910

10-
class MeasurementIDAnalysis(Forward[MeasureId]):
11+
@dataclass
12+
class MeasureIDFrame(ForwardFrame[MeasureId]):
13+
num_measures_at_stmt: dict[ir.Statement, int] = field(default_factory=dict)
14+
15+
16+
class MeasurementIDAnalysis(ForwardExtra[MeasureIDFrame, MeasureId]):
1117

1218
keys = ["measure_id"]
1319
lattice = MeasureId
1420
# for every kind of measurement encountered, increment this
1521
# then use this to generate the negative values for target rec indices
1622
measure_count = 0
1723

24+
def initialize_frame(
25+
self, code: ir.Statement, *, has_parent_access: bool = False
26+
) -> MeasureIDFrame:
27+
return MeasureIDFrame(code, has_parent_access=has_parent_access)
28+
1829
# Still default to bottom,
1930
# but let constants return the softer "NoMeasureId" type from impl
2031
def eval_stmt_fallback(

src/bloqade/analysis/measure_id/impls.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from kirin import types as kirin_types, interp
2+
from kirin.analysis import const
23
from kirin.dialects import py, scf, func, ilist
34

45
from bloqade.squin import wire, qubit
@@ -10,7 +11,7 @@
1011
MeasureIdTuple,
1112
InvalidMeasureId,
1213
)
13-
from .analysis import MeasurementIDAnalysis
14+
from .analysis import MeasureIDFrame, MeasurementIDAnalysis
1415

1516
## Can't do wire right now because of
1617
## unresolved RFC on return type
@@ -152,4 +153,33 @@ def invoke(
152153
# scf, particularly IfElse
153154
@scf.dialect.register(key="measure_id")
154155
class Scf(scf.absint.Methods):
155-
pass
156+
157+
@interp.impl(scf.IfElse)
158+
def if_else(
159+
self,
160+
interp_: MeasurementIDAnalysis,
161+
frame: MeasureIDFrame,
162+
stmt: scf.IfElse,
163+
):
164+
165+
frame.num_measures_at_stmt[stmt] = interp_.measure_count
166+
167+
# rest of the code taken directly from scf.absint.Methods base implementation
168+
169+
if isinstance(hint := stmt.cond.hints.get("const"), const.Value):
170+
if hint.data:
171+
return self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body)
172+
else:
173+
return self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body)
174+
then_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body)
175+
else_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body)
176+
177+
match (then_results, else_results):
178+
case (interp.ReturnValue(then_value), interp.ReturnValue(else_value)):
179+
return interp.ReturnValue(then_value.join(else_value))
180+
case (interp.ReturnValue(then_value), _):
181+
return then_results
182+
case (_, interp.ReturnValue(else_value)):
183+
return else_results
184+
case _:
185+
return interp_.join_results(then_results, else_results)

test/analysis/measure_id/test_measure_id.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
1-
# from kirin.types import AnyType
2-
# from kirin.dialects.ilist import IList
1+
from kirin.passes import HintConst
2+
from kirin.dialects import scf
33

44
from bloqade.squin import op, qubit, kernel
55
from bloqade.analysis.measure_id import MeasurementIDAnalysis
6-
from bloqade.analysis.measure_id.lattice import MeasureIdBool, MeasureIdTuple
6+
from bloqade.analysis.measure_id.lattice import (
7+
NotMeasureId,
8+
MeasureIdBool,
9+
MeasureIdTuple,
10+
)
11+
12+
13+
def results_at(kern, block_id, stmt_id):
14+
return kern.code.body.blocks[block_id].stmts.at(stmt_id).results # type: ignore
715

816

917
def test_add():
@@ -29,3 +37,74 @@ def test():
2937
data=tuple([MeasureIdBool(idx=i) for i in range(1, 11)])
3038
)
3139
assert measure_id_tuples[-1] == expected_measure_id_tuple
40+
41+
42+
def test_measure_count_at_if_else():
43+
44+
@kernel
45+
def test():
46+
q = qubit.new(5)
47+
qubit.apply(op.x(), q[2])
48+
ms = qubit.measure(q)
49+
50+
if ms[1]:
51+
qubit.apply(op.x(), q[0])
52+
53+
if ms[3]:
54+
qubit.apply(op.y(), q[1])
55+
56+
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
57+
58+
assert all(
59+
isinstance(stmt, scf.IfElse) and measures_accumulated == 5
60+
for stmt, measures_accumulated in frame.num_measures_at_stmt.items()
61+
)
62+
63+
64+
def test_scf_cond_true():
65+
@kernel
66+
def test():
67+
q = qubit.new(1)
68+
qubit.apply(op.x(), q[2])
69+
70+
ms = None
71+
cond = True
72+
if cond:
73+
ms = qubit.measure(q)
74+
else:
75+
ms = qubit.measure(q[0])
76+
77+
return ms
78+
79+
HintConst(dialects=test.dialects).unsafe_run(test)
80+
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
81+
82+
assert [frame.entries[result] for result in results_at(test, 0, 7)] == [
83+
NotMeasureId(),
84+
MeasureIdTuple((MeasureIdBool(idx=1),)),
85+
]
86+
87+
88+
def test_scf_cond_false():
89+
90+
@kernel
91+
def test():
92+
q = qubit.new(5)
93+
qubit.apply(op.x(), q[2])
94+
95+
ms = None
96+
cond = False
97+
if cond:
98+
ms = qubit.measure(q)
99+
else:
100+
ms = qubit.measure(q[0])
101+
102+
return ms
103+
104+
HintConst(dialects=test.dialects).unsafe_run(test)
105+
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
106+
107+
assert [frame.entries[result] for result in results_at(test, 0, 7)] == [
108+
NotMeasureId(),
109+
MeasureIdBool(idx=1),
110+
]

0 commit comments

Comments
 (0)