Skip to content

Commit 6b4959e

Browse files
johnzl-777kaihsin
andauthored
Support Slice propogation in MeasurementIDAnalysis (#438)
I was working on a more advanced unit test for SquinToStim when I ran into this problem. If you use slice on a list of measurements the analysis doesn't go through. The solution was to just carry over the logic I figured out for the address analysis. However I cooked up a case where even this is insufficient and figured out the simplest addition to the logic to amend this case. Basically both measure and address have something like the following: ``` # container could be AddressTuple or MeasureIDTuple getitem(container, slice) -> tuple(....) ``` Where tuple is literally a plain Python tuple. This doesn't present any problems until you slice again: ``` getitem(tuple, slice) -> ??? ``` in which case the method table impl is blind to how to handle a raw tuple. I believe the more correct way is that slice should generate ANOTHER container type instead of a naked python tuple. --------- Co-authored-by: Kai-Hsin Wu <[email protected]>
1 parent 8670b8b commit 6b4959e

File tree

3 files changed

+101
-13
lines changed

3 files changed

+101
-13
lines changed

src/bloqade/analysis/measure_id/analysis.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import TypeVar
22
from dataclasses import field, dataclass
33

4-
from kirin import ir, interp
4+
from kirin import ir
55
from kirin.analysis import ForwardExtra, const
66
from kirin.analysis.forward import ForwardFrame
77

@@ -37,20 +37,19 @@ def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]):
3737
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
3838
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
3939

40-
T = TypeVar("T")
41-
4240
# Xiu-zhe (Roger) Luo came up with this in the address analysis,
43-
# reused here for convenience
41+
# reused here for convenience (now modified to be a bit more graceful)
4442
# TODO: Remove this function once upgrade to kirin 0.18 happens,
4543
# method is built-in to interpreter then
46-
def get_const_value(self, input_type: type[T], value: ir.SSAValue) -> T:
44+
45+
T = TypeVar("T")
46+
47+
def get_const_value(
48+
self, input_type: type[T], value: ir.SSAValue
49+
) -> type[T] | None:
4750
if isinstance(hint := value.hints.get("const"), const.Value):
4851
data = hint.data
4952
if isinstance(data, input_type):
5053
return hint.data
51-
raise interp.InterpreterError(
52-
f"Expected constant value <type = {input_type}>, got {data}"
53-
)
54-
raise interp.InterpreterError(
55-
f"Expected constant value <type = {input_type}>, got {value}"
56-
)
54+
55+
return None

src/bloqade/analysis/measure_id/impls.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,23 @@ class PyIndexing(interp.MethodTable):
103103
def getitem(
104104
self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.GetItem
105105
):
106-
idx = interp.get_const_value(int, stmt.index)
106+
107+
idx_or_slice = interp.get_const_value((int, slice), stmt.index)
108+
if idx_or_slice is None:
109+
return (InvalidMeasureId(),)
110+
111+
# hint = stmt.index.hints.get("const")
112+
# if hint is None or not isinstance(hint, const.Value):
113+
# return (InvalidMeasureId(),)
114+
107115
obj = frame.get(stmt.obj)
108116
if isinstance(obj, MeasureIdTuple):
109-
return (obj.data[idx],)
117+
if isinstance(idx_or_slice, slice):
118+
return (MeasureIdTuple(data=obj.data[idx_or_slice]),)
119+
elif isinstance(idx_or_slice, int):
120+
return (obj.data[idx_or_slice],)
121+
else:
122+
return (InvalidMeasureId(),)
110123
# just propagate these down the line
111124
elif isinstance(obj, (AnyMeasureId, NotMeasureId)):
112125
return (obj,)

test/analysis/measure_id/test_measure_id.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
NotMeasureId,
88
MeasureIdBool,
99
MeasureIdTuple,
10+
InvalidMeasureId,
1011
)
1112

1213

@@ -139,3 +140,78 @@ def test():
139140
NotMeasureId(),
140141
MeasureIdBool(idx=1),
141142
]
143+
144+
145+
def test_slice():
146+
@kernel
147+
def test():
148+
q = qubit.new(6)
149+
qubit.apply(op.x(), q[2])
150+
151+
ms = qubit.measure(q)
152+
msi = ms[1:] # MeasureIdTuple becomes a python tuple
153+
msi2 = msi[1:] # slicing should still work on previous tuple
154+
ms_final = msi2[::2]
155+
156+
return ms_final
157+
158+
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
159+
160+
test.print(analysis=frame.entries)
161+
162+
assert [frame.entries[result] for result in results_at(test, 0, 8)] == [
163+
MeasureIdTuple(data=tuple(list(MeasureIdBool(idx=i) for i in range(2, 7))))
164+
]
165+
assert [frame.entries[result] for result in results_at(test, 0, 10)] == [
166+
MeasureIdTuple(data=tuple(list(MeasureIdBool(idx=i) for i in range(3, 7))))
167+
]
168+
assert [frame.entries[result] for result in results_at(test, 0, 12)] == [
169+
MeasureIdTuple(data=(MeasureIdBool(idx=3), MeasureIdBool(idx=5)))
170+
]
171+
172+
173+
def test_getitem_no_hint():
174+
@kernel
175+
def test(idx):
176+
q = qubit.new(6)
177+
ms = qubit.measure(q)
178+
179+
return ms[idx]
180+
181+
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
182+
183+
assert [frame.entries[result] for result in results_at(test, 0, 3)] == [
184+
InvalidMeasureId(),
185+
]
186+
187+
188+
def test_getitem_invalid_hint():
189+
@kernel
190+
def test():
191+
q = qubit.new(6)
192+
ms = qubit.measure(q)
193+
194+
return ms["x"]
195+
196+
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
197+
198+
assert [frame.entries[result] for result in results_at(test, 0, 4)] == [
199+
InvalidMeasureId()
200+
]
201+
202+
203+
def test_getitem_propagate_invalid_measure():
204+
205+
@kernel
206+
def test():
207+
q = qubit.new(6)
208+
ms = qubit.measure(q)
209+
# this will return an InvalidMeasureId
210+
invalid_ms = ms["x"]
211+
return invalid_ms[0]
212+
213+
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
214+
215+
assert [frame.entries[result] for result in results_at(test, 0, 6)] == [
216+
InvalidMeasureId()
217+
]

0 commit comments

Comments
 (0)