@@ -17,6 +17,16 @@ def results_at(kern, block_id, stmt_id):
1717 return kern .code .body .blocks [block_id ].stmts .at (stmt_id ).results # type: ignore
1818
1919
20+ def results_of_variables (kernel , variable_names ):
21+ results = {}
22+ for stmt in kernel .callable_region .stmts ():
23+ for result in stmt .results :
24+ if result .name in variable_names :
25+ results [result .name ] = result
26+
27+ return results
28+
29+
2030def test_add ():
2131 @squin .kernel
2232 def test ():
@@ -185,9 +195,6 @@ def test(cond: bool):
185195 assert list (frame .entries .values ())[- 2 :] == [AnyMeasureId (), AnyMeasureId ()]
186196
187197
188- test_scf_cond_unknown ()
189-
190-
191198def test_slice ():
192199 @squin .kernel
193200 def test ():
@@ -204,18 +211,20 @@ def test():
204211 Flatten (test .dialects ).fixpoint (test )
205212 frame , _ = MeasurementIDAnalysis (test .dialects ).run_analysis (test )
206213
214+ results = results_of_variables (test , ("msi" , "msi2" , "ms_final" ))
215+
207216 # This is an assertion against `msi` NOT the initial list of measurements
208- assert [ frame .entries [ result ] for result in results_at ( test , 0 , 11 )] == [
209- MeasureIdTuple ( data = tuple (list (MeasureIdBool (idx = i ) for i in range (2 , 7 ) )))
210- ]
217+ assert frame .get ( results [ "msi" ]) == MeasureIdTuple (
218+ data = tuple (list (MeasureIdBool (idx = i ) for i in range (2 , 7 )))
219+ )
211220 # msi2
212- assert [ frame .entries [ result ] for result in results_at ( test , 0 , 12 )] == [
213- MeasureIdTuple ( data = tuple (list (MeasureIdBool (idx = i ) for i in range (3 , 7 ) )))
214- ]
221+ assert frame .get ( results [ "msi2" ]) == MeasureIdTuple (
222+ data = tuple (list (MeasureIdBool (idx = i ) for i in range (3 , 7 )))
223+ )
215224 # ms_final
216- assert [ frame .entries [ result ] for result in results_at ( test , 0 , 14 )] == [
217- MeasureIdTuple ( data = (MeasureIdBool (idx = 3 ), MeasureIdBool (idx = 5 ) ))
218- ]
225+ assert frame .get ( results [ "ms_final" ]) == MeasureIdTuple (
226+ data = (MeasureIdBool (idx = 3 ), MeasureIdBool (idx = 5 ))
227+ )
219228
220229
221230def test_getitem_no_hint ():
0 commit comments