11from collections .abc import Iterable
22
3- from kirin import interp
3+ from kirin import ir , interp
44from kirin .analysis import const
55
66from .stmts import For , Yield , IfElse
@@ -32,35 +32,50 @@ def if_else(
3232 ):
3333 cond = frame .get (stmt .cond )
3434 if isinstance (cond , const .Value ):
35- with interp_ .state .new_frame (interp_ .new_frame (stmt )) as body_frame :
36- body_frame .entries .update (frame .entries )
37- if cond .data :
38- results = interp_ .run_ssacfg_region (body_frame , stmt .then_body )
39- else :
40- results = interp_ .run_ssacfg_region (body_frame , stmt .else_body )
35+ if cond .data :
36+ body = stmt .then_body
37+ else :
38+ body = stmt .else_body
39+ body_frame , ret = self ._prop_const_cond_ifelse (
40+ interp_ , frame , stmt , cond , body
41+ )
42+ frame .entries .update (body_frame .entries )
43+ return ret
44+ else :
45+ then_frame , then_results = self ._prop_const_cond_ifelse (
46+ interp_ , frame , stmt , const .Value (True ), stmt .then_body
47+ )
48+ else_frame , else_results = self ._prop_const_cond_ifelse (
49+ interp_ , frame , stmt , const .Value (False ), stmt .else_body
50+ )
51+ ret = interp_ .join_results (then_results , else_results )
4152
42- if not body_frame .frame_is_not_pure :
53+ if not then_frame . frame_is_not_pure or not else_frame .frame_is_not_pure :
4354 frame .should_be_pure .add (stmt )
44- else :
45- with interp_ .state .new_frame (interp_ .new_frame (stmt )) as then_body_frame :
46- then_body_frame .entries .update (frame .entries )
47- then_results = interp_ .run_ssacfg_region (
48- then_body_frame , stmt .then_body
49- )
5055
51- with interp_ . state . new_frame ( interp_ . new_frame ( stmt )) as else_body_frame :
52- else_body_frame . entries . update ( frame . entries )
53- else_results = interp_ . run_ssacfg_region (
54- else_body_frame , stmt . else_body
55- )
56- results = interp_ . join_results ( then_results , else_results )
56+ # NOTE: then_frame and else_frame do not change
57+ # parent frame variables value except cond
58+ frame . entries . update ( then_frame . entries )
59+ frame . entries . update ( else_frame . entries )
60+ frame . set ( stmt . cond , cond )
61+ return ret
5762
58- if (
59- not then_body_frame .frame_is_not_pure
60- or not else_body_frame .frame_is_not_pure
61- ):
62- frame .should_be_pure .add (stmt )
63- return results
63+ def _prop_const_cond_ifelse (
64+ self ,
65+ interp_ : const .Propagate ,
66+ frame : const .Frame ,
67+ stmt : IfElse ,
68+ cond : const .Value ,
69+ body : ir .Region ,
70+ ):
71+ with interp_ .state .new_frame (interp_ .new_frame (stmt )) as body_frame :
72+ body_frame .entries .update (frame .entries )
73+ body_frame .set (body .blocks [0 ].args [0 ], cond )
74+ results = interp_ .run_ssacfg_region (body_frame , body )
75+
76+ if not body_frame .frame_is_not_pure :
77+ frame .should_be_pure .add (stmt )
78+ return body_frame , results
6479
6580 @interp .impl (For )
6681 def for_loop (
@@ -70,33 +85,44 @@ def for_loop(
7085 stmt : For ,
7186 ):
7287 iterable = frame .get (stmt .iterable )
73- loop_vars = frame .get_values (stmt .initializers )
74- block_args = stmt .body .blocks [0 ].args
75-
7688 if isinstance (iterable , const .Value ):
77- frame_is_not_pure = False
78- if not isinstance (iterable .data , Iterable ):
79- raise interp .InterpreterError (
80- f"Expected iterable, got { type (iterable .data )} "
81- )
82- for value in iterable .data :
83- with interp_ .state .new_frame (interp_ .new_frame (stmt )) as body_frame :
84- body_frame .entries .update (frame .entries )
85- body_frame .set_values (
86- block_args ,
87- (const .Value (value ),) + loop_vars ,
88- )
89- loop_vars = interp_ .run_ssacfg_region (body_frame , stmt .body )
90-
91- if body_frame .frame_is_not_pure :
92- frame_is_not_pure = True
93- if loop_vars is None :
94- loop_vars = ()
95- elif isinstance (loop_vars , interp .ReturnValue ):
96- return loop_vars
97-
98- if not frame_is_not_pure :
99- frame .should_be_pure .add (stmt )
100- return loop_vars
89+ return self ._prop_const_iterable_forloop (interp_ , frame , stmt , iterable )
10190 else : # TODO: support other iteration
10291 return tuple (interp_ .lattice .top () for _ in stmt .results )
92+
93+ def _prop_const_iterable_forloop (
94+ self ,
95+ interp_ : const .Propagate ,
96+ frame : const .Frame ,
97+ stmt : For ,
98+ iterable : const .Value ,
99+ ):
100+ frame_is_not_pure = False
101+ if not isinstance (iterable .data , Iterable ):
102+ raise interp .InterpreterError (
103+ f"Expected iterable, got { type (iterable .data )} "
104+ )
105+
106+ loop_vars = frame .get_values (stmt .initializers )
107+ body_block = stmt .body .blocks [0 ]
108+ block_args = body_block .args
109+
110+ for value in iterable .data :
111+ with interp_ .state .new_frame (interp_ .new_frame (stmt )) as body_frame :
112+ body_frame .entries .update (frame .entries )
113+ body_frame .set_values (
114+ block_args ,
115+ (const .Value (value ),) + loop_vars ,
116+ )
117+ loop_vars = interp_ .run_ssacfg_region (body_frame , stmt .body )
118+
119+ if body_frame .frame_is_not_pure :
120+ frame_is_not_pure = True
121+ if loop_vars is None :
122+ loop_vars = ()
123+ elif isinstance (loop_vars , interp .ReturnValue ):
124+ return loop_vars
125+
126+ if not frame_is_not_pure :
127+ frame .should_be_pure .add (stmt )
128+ return loop_vars
0 commit comments