11from kirin import interp
2- from kirin .analysis import ForwardFrame
2+ from kirin .analysis import ForwardFrame , const
33from kirin .dialects import scf
44
5- from bloqade .analysis .address import Address
5+ from bloqade .analysis .address import Address , ConstResult
66
77from .analysis import FidelityAnalysis
88
@@ -18,17 +18,27 @@ def if_else(
1818 current_gate_fidelities = interp_ .gate_fidelities
1919 current_survival_fidelities = interp_ .qubit_survival_fidelities
2020
21- # TODO: check if the condition is constant and fix the branch in that case
22- # run both branches
21+ address_cond = frame .get (stmt .cond )
22+
23+ # NOTE: if the condition is known at compile time, run specific branch
24+ if isinstance (address_cond , ConstResult ) and isinstance (
25+ const_cond := address_cond .result , const .Value
26+ ):
27+ body = stmt .then_body if const_cond .data else stmt .else_body
28+ with interp_ .new_frame (stmt , has_parent_access = True ) as body_frame :
29+ ret = interp_ .frame_call_region (body_frame , stmt , body , address_cond )
30+ return ret
31+
32+ # NOTE: runtime condition, evaluate both
2333 with interp_ .new_frame (stmt , has_parent_access = True ) as then_frame :
2434 # NOTE: reset fidelities before stepping into the then-body
2535 interp_ .reset_fidelities ()
2636
27- interp_ .frame_call_region (
37+ then_results = interp_ .frame_call_region (
2838 then_frame ,
2939 stmt ,
3040 stmt .then_body ,
31- * ( interp_ . lattice . bottom () for _ in range ( len ( stmt . args ))) ,
41+ address_cond ,
3242 )
3343 then_fids = interp_ .gate_fidelities
3444 then_survival = interp_ .qubit_survival_fidelities
@@ -37,11 +47,11 @@ def if_else(
3747 # NOTE: reset again before stepping into else-body
3848 interp_ .reset_fidelities ()
3949
40- interp_ .frame_call_region (
50+ else_results = interp_ .frame_call_region (
4151 else_frame ,
4252 stmt ,
4353 stmt .else_body ,
44- * ( interp_ . lattice . bottom () for _ in range ( len ( stmt . args ))) ,
54+ address_cond ,
4555 )
4656
4757 else_fids = interp_ .gate_fidelities
@@ -60,3 +70,17 @@ def if_else(
6070 then_survival ,
6171 else_survival ,
6272 )
73+
74+ # TODO: pick the non-return value
75+ if isinstance (then_results , interp .ReturnValue ) and isinstance (
76+ else_results , interp .ReturnValue
77+ ):
78+ return interp .ReturnValue (then_results .value .join (else_results .value ))
79+ elif isinstance (then_results , interp .ReturnValue ):
80+ ret = else_results
81+ elif isinstance (else_results , interp .ReturnValue ):
82+ ret = then_results
83+ else :
84+ ret = interp_ .join_results (then_results , else_results )
85+
86+ return ret
0 commit comments