Skip to content

Commit 3abe25f

Browse files
committed
Fix test_slice on python 3.10
1 parent a540899 commit 3abe25f

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

test/analysis/measure_id/test_measure_id.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,16 @@ def results_at(kern, block_id, stmt_id):
1717
return kern.code.body.blocks[block_id].stmts.at(stmt_id).results # type: ignore
1818

1919

20+
def results_of_variables(kernel, variable_names):
21+
results = {}
22+
for stmt in kernel.callable_region.stmts():
23+
for result in stmt.results:
24+
if result.name in variable_names:
25+
results[result.name] = result
26+
27+
return results
28+
29+
2030
def test_add():
2131
@squin.kernel
2232
def test():
@@ -185,9 +195,6 @@ def test(cond: bool):
185195
assert list(frame.entries.values())[-2:] == [AnyMeasureId(), AnyMeasureId()]
186196

187197

188-
test_scf_cond_unknown()
189-
190-
191198
def test_slice():
192199
@squin.kernel
193200
def test():
@@ -204,18 +211,20 @@ def test():
204211
Flatten(test.dialects).fixpoint(test)
205212
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
206213

214+
results = results_of_variables(test, ("msi", "msi2", "ms_final"))
215+
207216
# This is an assertion against `msi` NOT the initial list of measurements
208-
assert [frame.entries[result] for result in results_at(test, 0, 11)] == [
209-
MeasureIdTuple(data=tuple(list(MeasureIdBool(idx=i) for i in range(2, 7))))
210-
]
217+
assert frame.get(results["msi"]) == MeasureIdTuple(
218+
data=tuple(list(MeasureIdBool(idx=i) for i in range(2, 7)))
219+
)
211220
# msi2
212-
assert [frame.entries[result] for result in results_at(test, 0, 12)] == [
213-
MeasureIdTuple(data=tuple(list(MeasureIdBool(idx=i) for i in range(3, 7))))
214-
]
221+
assert frame.get(results["msi2"]) == MeasureIdTuple(
222+
data=tuple(list(MeasureIdBool(idx=i) for i in range(3, 7)))
223+
)
215224
# ms_final
216-
assert [frame.entries[result] for result in results_at(test, 0, 14)] == [
217-
MeasureIdTuple(data=(MeasureIdBool(idx=3), MeasureIdBool(idx=5)))
218-
]
225+
assert frame.get(results["ms_final"]) == MeasureIdTuple(
226+
data=(MeasureIdBool(idx=3), MeasureIdBool(idx=5))
227+
)
219228

220229

221230
def test_getitem_no_hint():

0 commit comments

Comments
 (0)