Skip to content

Commit c6573cd

Browse files
committed
updated stim debug
2 parents 88db9e9 + 9913311 commit c6573cd

File tree

4 files changed

+96
-62
lines changed

4 files changed

+96
-62
lines changed

src/bloqade/stim/emit/impls.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
1-
from kirin.emit import EmitStrFrame
21
from kirin.interp import MethodTable, impl
32
from kirin.dialects.debug import Info, dialect
43

5-
from bloqade.stim.emit.stim_str import EmitStimMain
4+
from bloqade.stim.emit.stim_str import EmitStimMain, EmitStimFrame
65

76

87
@dialect.register(key="emit.stim")
98
class EmitStimDebugMethods(MethodTable):
109

1110
@impl(Info)
12-
def info(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: Info):
11+
def info(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: Info):
1312

1413
msg: str = frame.get(stmt.msg)
15-
emit.writeln(frame, f"# {msg}")
14+
frame.write_line(f"# {msg}")
1615

1716
return ()

src/bloqade/stim/emit/stim_str.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import sys
22
from typing import IO, Generic, TypeVar, cast
3-
from contextlib import contextmanager
4-
from dataclasses import field, dataclass
3+
from dataclasses import dataclass
54

65
from kirin import ir, interp
7-
from kirin.idtable import IdTable
86
from kirin.dialects import func
97
from kirin.emit.abc import EmitABC, EmitFrame
108

@@ -14,28 +12,13 @@
1412
@dataclass
1513
class EmitStimFrame(EmitFrame[str], Generic[IO_t]):
1614
io: IO_t = cast(IO_t, sys.stdout)
17-
ssa: IdTable[ir.SSAValue] = field(
18-
default_factory=lambda: IdTable[ir.SSAValue](prefix="ssa_")
19-
)
20-
block: IdTable[ir.Block] = field(
21-
default_factory=lambda: IdTable[ir.Block](prefix="block_")
22-
)
23-
_indent: int = 0
2415

2516
def write(self, value: str) -> None:
2617
self.io.write(value)
2718

2819
def write_line(self, value: str) -> None:
2920
self.write(" " * self._indent + value + "\n")
3021

31-
@contextmanager
32-
def indent(self):
33-
self._indent += 1
34-
try:
35-
yield
36-
finally:
37-
self._indent -= 1
38-
3922

4023
@dataclass
4124
class EmitStimMain(EmitABC[EmitStimFrame, str], Generic[IO_t]):

test/analysis/measure_id/test_measure_id.py

Lines changed: 84 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
import pytest
2-
from kirin.passes import HintConst, inline
31
from kirin.dialects import scf
2+
from kirin.passes.inline import InlinePass
43

54
from bloqade import squin
65
from bloqade.analysis.measure_id import MeasurementIDAnalysis
6+
from bloqade.stim.passes.flatten import Flatten
77
from bloqade.analysis.measure_id.lattice import (
8+
NotMeasureId,
89
MeasureIdBool,
910
MeasureIdTuple,
1011
InvalidMeasureId,
@@ -15,7 +16,16 @@ def results_at(kern, block_id, stmt_id):
1516
return kern.code.body.blocks[block_id].stmts.at(stmt_id).results # type: ignore
1617

1718

18-
@pytest.mark.xfail
19+
def results_of_variables(kernel, variable_names):
20+
results = {}
21+
for stmt in kernel.callable_region.stmts():
22+
for result in stmt.results:
23+
if result.name in variable_names:
24+
results[result.name] = result
25+
26+
return results
27+
28+
1929
def test_add():
2030
@squin.kernel
2131
def test():
@@ -28,6 +38,8 @@ def test():
2838
ml2 = squin.broadcast.measure(ql2)
2939
return ml1 + ml2
3040

41+
Flatten(test.dialects).fixpoint(test)
42+
3143
frame, _ = MeasurementIDAnalysis(test.dialects).run(test)
3244

3345
measure_id_tuples = [
@@ -41,7 +53,6 @@ def test():
4153
assert measure_id_tuples[-1] == expected_measure_id_tuple
4254

4355

44-
@pytest.mark.xfail
4556
def test_measure_alias():
4657

4758
@squin.kernel
@@ -52,28 +63,33 @@ def test():
5263

5364
return ml_alias
5465

66+
Flatten(test.dialects).fixpoint(test)
5567
frame, _ = MeasurementIDAnalysis(test.dialects).run(test)
5668

57-
test.print(analysis=frame.entries)
58-
5969
# Collect MeasureIdTuples
6070
measure_id_tuples = [
6171
value for value in frame.entries.values() if isinstance(value, MeasureIdTuple)
6272
]
6373

64-
# construct expected MeasureIdTuple
65-
expected_measure_id_tuple = MeasureIdTuple(
74+
# construct expected MeasureIdTuples
75+
measure_id_tuple_with_id_bools = MeasureIdTuple(
6676
data=tuple([MeasureIdBool(idx=i) for i in range(1, 6)])
6777
)
78+
measure_id_tuple_with_not_measures = MeasureIdTuple(
79+
data=tuple([NotMeasureId() for _ in range(5)])
80+
)
6881

69-
assert len(measure_id_tuples) == 2
82+
assert len(measure_id_tuples) == 3
83+
# New qubit.new semantics cause a MeasureIdTuple to be generated full of NotMeasureIds because
84+
# qubit.new is actually an ilist.map that invokes single qubit allocation multiple times
85+
# and puts them into an ilist.
86+
assert measure_id_tuples[0] == measure_id_tuple_with_not_measures
7087
assert all(
71-
measure_id_tuple == expected_measure_id_tuple
72-
for measure_id_tuple in measure_id_tuples
88+
measure_id_tuple == measure_id_tuple_with_id_bools
89+
for measure_id_tuple in measure_id_tuples[1:]
7390
)
7491

7592

76-
@pytest.mark.xfail
7793
def test_measure_count_at_if_else():
7894

7995
@squin.kernel
@@ -88,6 +104,7 @@ def test():
88104
if ms[3]:
89105
squin.y(q[1])
90106

107+
Flatten(test.dialects).fixpoint(test)
91108
frame, _ = MeasurementIDAnalysis(test.dialects).run(test)
92109

93110
assert all(
@@ -96,32 +113,29 @@ def test():
96113
)
97114

98115

99-
@pytest.mark.xfail
100116
def test_scf_cond_true():
101117
@squin.kernel
102118
def test():
103-
q = squin.qalloc(1)
119+
q = squin.qalloc(3)
104120
squin.x(q[2])
105121

106122
ms = None
107123
cond = True
108124
if cond:
109-
ms = squin.broadcast.measure(q)
125+
ms = squin.measure(q[1])
110126
else:
111127
ms = squin.measure(q[0])
112128

113129
return ms
114130

115-
HintConst(dialects=test.dialects).unsafe_run(test)
131+
InlinePass(test.dialects).fixpoint(test)
116132
frame, _ = MeasurementIDAnalysis(test.dialects).run(test)
117133

118-
# MeasureIdTuple(data=MeasureIdBool(idx=1),) should occur twice:
134+
# MeasureIdBool(idx=1) should occur twice:
119135
# First from the measurement in the true branch, then
120136
# the result of the scf.IfElse itself
121137
analysis_results = [
122-
val
123-
for val in frame.entries.values()
124-
if val == MeasureIdTuple(data=(MeasureIdBool(idx=1),))
138+
val for val in frame.entries.values() if val == MeasureIdBool(idx=1)
125139
]
126140
assert len(analysis_results) == 2
127141

@@ -136,16 +150,16 @@ def test():
136150
ms = None
137151
cond = False
138152
if cond:
139-
ms = squin.broadcast.measure(q)
153+
ms = squin.measure(q[1])
140154
else:
141-
ms = squin.qubit.measure(q[0])
155+
ms = squin.measure(q[0])
142156

143157
return ms
144158

145-
inline.InlinePass(test.dialects).fixpoint(test)
146-
147-
HintConst(dialects=test.dialects).unsafe_run(test)
159+
# need to preserve the scf.IfElse but need things like qalloc to be inlined
160+
InlinePass(test.dialects).fixpoint(test)
148161
frame, _ = MeasurementIDAnalysis(test.dialects).run(test)
162+
test.print(analysis=frame.entries)
149163

150164
# MeasureIdBool(idx=1) should occur twice:
151165
# First from the measurement in the false branch, then
@@ -156,7 +170,37 @@ def test():
156170
assert len(analysis_results) == 2
157171

158172

159-
@pytest.mark.xfail
173+
def test_scf_cond_unknown():
174+
175+
@squin.kernel
176+
def test(cond: bool):
177+
q = squin.qalloc(5)
178+
squin.x(q[2])
179+
180+
if cond:
181+
ms = squin.broadcast.measure(q)
182+
else:
183+
ms = squin.measure(q[0])
184+
185+
return ms
186+
187+
# We can use Flatten here because the variable condition for the scf.IfElse
188+
# means it cannot be simplified.
189+
Flatten(test.dialects).fixpoint(test)
190+
frame, _ = MeasurementIDAnalysis(test.dialects).run(test)
191+
analysis_results = [
192+
val for val in frame.entries.values() if isinstance(val, MeasureIdTuple)
193+
]
194+
# Both branches of the scf.IfElse should be properly traversed and contain the following
195+
# analysis results.
196+
expected_full_register_measurement = MeasureIdTuple(
197+
data=tuple([MeasureIdBool(idx=i) for i in range(1, 6)])
198+
)
199+
expected_else_measurement = MeasureIdTuple(data=(MeasureIdBool(idx=6),))
200+
assert expected_full_register_measurement in analysis_results
201+
assert expected_else_measurement in analysis_results
202+
203+
160204
def test_slice():
161205
@squin.kernel
162206
def test():
@@ -170,19 +214,23 @@ def test():
170214

171215
return ms_final
172216

217+
Flatten(test.dialects).fixpoint(test)
173218
frame, _ = MeasurementIDAnalysis(test.dialects).run(test)
174219

175-
test.print(analysis=frame.entries)
220+
results = results_of_variables(test, ("msi", "msi2", "ms_final"))
176221

177-
assert [frame.entries[result] for result in results_at(test, 0, 7)] == [
178-
MeasureIdTuple(data=tuple(list(MeasureIdBool(idx=i) for i in range(2, 7))))
179-
]
180-
assert [frame.entries[result] for result in results_at(test, 0, 9)] == [
181-
MeasureIdTuple(data=tuple(list(MeasureIdBool(idx=i) for i in range(3, 7))))
182-
]
183-
assert [frame.entries[result] for result in results_at(test, 0, 11)] == [
184-
MeasureIdTuple(data=(MeasureIdBool(idx=3), MeasureIdBool(idx=5)))
185-
]
222+
# This is an assertion against `msi` NOT the initial list of measurements
223+
assert frame.get(results["msi"]) == MeasureIdTuple(
224+
data=tuple(list(MeasureIdBool(idx=i) for i in range(2, 7)))
225+
)
226+
# msi2
227+
assert frame.get(results["msi2"]) == MeasureIdTuple(
228+
data=tuple(list(MeasureIdBool(idx=i) for i in range(3, 7)))
229+
)
230+
# ms_final
231+
assert frame.get(results["ms_final"]) == MeasureIdTuple(
232+
data=(MeasureIdBool(idx=3), MeasureIdBool(idx=5))
233+
)
186234

187235

188236
def test_getitem_no_hint():
Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import io
2+
13
from kirin.dialects import debug
24

35
from bloqade import stim
4-
5-
from .base import codegen
6+
from bloqade.stim.emit import EmitStimMain
67

78

89
def test_debug():
@@ -12,5 +13,8 @@ def test_debug_main():
1213
debug.info("debug message")
1314

1415
test_debug_main.print()
15-
out = codegen(test_debug_main)
16-
assert out.strip() == "# debug message"
16+
17+
buf = io.StringIO()
18+
stim_emit: EmitStimMain[io.StringIO] = EmitStimMain(dialects=stim.main, io=buf)
19+
stim_emit.run(test_debug_main)
20+
assert buf.getvalue().strip() == "# debug message"

0 commit comments

Comments
 (0)