1- import pytest
21from kirin .dialects import scf
32from kirin .passes .fold import Fold
43from kirin .passes .inline import InlinePass
@@ -106,7 +105,6 @@ def test():
106105 )
107106
108107
109- @pytest .mark .xfail
110108def test_scf_cond_true ():
111109 @squin .kernel
112110 def test ():
@@ -116,23 +114,20 @@ def test():
116114 ms = None
117115 cond = True
118116 if cond :
119- ms = squin .broadcast . measure (q )
117+ ms = squin .measure (q [ 1 ] )
120118 else :
121119 ms = squin .measure (q [0 ])
122120
123121 return ms
124122
125123 InlinePass (test .dialects ).fixpoint (test )
126124 frame , _ = MeasurementIDAnalysis (test .dialects ).run_analysis (test )
127- test .print (analysis = frame .entries )
128125
129- # MeasureIdTuple(data= MeasureIdBool(idx=1), ) should occur twice:
126+ # MeasureIdBool(idx=1) should occur twice:
130127 # First from the measurement in the true branch, then
131128 # the result of the scf.IfElse itself
132129 analysis_results = [
133- val
134- for val in frame .entries .values ()
135- if val == MeasureIdTuple (data = (MeasureIdBool (idx = 1 ),))
130+ val for val in frame .entries .values () if val == MeasureIdBool (idx = 1 )
136131 ]
137132 assert len (analysis_results ) == 2
138133
@@ -147,7 +142,7 @@ def test():
147142 ms = None
148143 cond = False
149144 if cond :
150- ms = squin .broadcast . measure (q )
145+ ms = squin .measure (q [ 1 ] )
151146 else :
152147 ms = squin .measure (q [0 ])
153148
0 commit comments