@@ -154,7 +154,34 @@ pub enum SimpleLine {
154154pub fn simplify_program ( mut program : Program ) -> SimpleProgram {
155155 check_program_scoping ( & program) ;
156156 handle_inlined_functions ( & mut program) ;
157- handle_const_arguments ( & mut program) ;
157+
158+ // Iterate between unrolling and const argument handling until fixed point
159+ let mut unroll_counter = Counter :: new ( ) ;
160+ let mut max_iterations = 100 ;
161+ loop {
162+ let mut any_change = false ;
163+
164+ any_change |= unroll_loops_in_program ( & mut program, & mut unroll_counter) ;
165+ any_change |= handle_const_arguments ( & mut program) ;
166+
167+ max_iterations -= 1 ;
168+ assert ! ( max_iterations > 0 , "Too many iterations while simplifying program" ) ;
169+ if !any_change {
170+ break ;
171+ }
172+ }
173+
174+ // Remove all const functions - they should all have been specialized by now
175+ let const_func_names: Vec < _ > = program
176+ . functions
177+ . iter ( )
178+ . filter ( |( _, func) | func. has_const_arguments ( ) )
179+ . map ( |( name, _) | name. clone ( ) )
180+ . collect ( ) ;
181+ for name in const_func_names {
182+ program. functions . remove ( & name) ;
183+ }
184+
158185 let mut new_functions = BTreeMap :: new ( ) ;
159186 let mut counters = Counters :: default ( ) ;
160187 let mut const_malloc = ConstMalloc :: default ( ) ;
@@ -197,6 +224,88 @@ pub fn simplify_program(mut program: Program) -> SimpleProgram {
197224 }
198225}
199226
227+ fn unroll_loops_in_program ( program : & mut Program , unroll_counter : & mut Counter ) -> bool {
228+ let mut changed = false ;
229+ for func in program. functions . values_mut ( ) {
230+ changed |= unroll_loops_in_lines ( & mut func. body , & program. const_arrays , unroll_counter) ;
231+ }
232+ changed
233+ }
234+
235+ fn unroll_loops_in_lines (
236+ lines : & mut Vec < Line > ,
237+ const_arrays : & BTreeMap < String , Vec < usize > > ,
238+ unroll_counter : & mut Counter ,
239+ ) -> bool {
240+ let mut changed = false ;
241+ let mut i = 0 ;
242+ while i < lines. len ( ) {
243+ // First, recursively process nested structures
244+ match & mut lines[ i] {
245+ Line :: ForLoop { body, .. } => {
246+ changed |= unroll_loops_in_lines ( body, const_arrays, unroll_counter) ;
247+ }
248+ Line :: IfCondition {
249+ then_branch,
250+ else_branch,
251+ ..
252+ } => {
253+ changed |= unroll_loops_in_lines ( then_branch, const_arrays, unroll_counter) ;
254+ changed |= unroll_loops_in_lines ( else_branch, const_arrays, unroll_counter) ;
255+ }
256+ Line :: Match { arms, .. } => {
257+ for ( _, arm_body) in arms {
258+ changed |= unroll_loops_in_lines ( arm_body, const_arrays, unroll_counter) ;
259+ }
260+ }
261+ _ => { }
262+ }
263+
264+ // Now try to unroll if it's an unrollable loop
265+ if let Line :: ForLoop {
266+ iterator,
267+ start,
268+ end,
269+ body,
270+ rev,
271+ unroll : true ,
272+ line_number : _,
273+ } = & lines[ i]
274+ && let ( Some ( start_val) , Some ( end_val) ) = ( start. naive_eval ( const_arrays) , end. naive_eval ( const_arrays) )
275+ {
276+ let start_usize = start_val. to_usize ( ) ;
277+ let end_usize = end_val. to_usize ( ) ;
278+ let unroll_index = unroll_counter. next ( ) ;
279+
280+ let ( internal_vars, _) = find_variable_usage ( body, const_arrays) ;
281+
282+ let mut range: Vec < _ > = ( start_usize..end_usize) . collect ( ) ;
283+ if * rev {
284+ range. reverse ( ) ;
285+ }
286+
287+ let iterator = iterator. clone ( ) ;
288+ let body = body. clone ( ) ;
289+
290+ let mut unrolled = Vec :: new ( ) ;
291+ for j in range {
292+ let mut body_copy = body. clone ( ) ;
293+ replace_vars_for_unroll ( & mut body_copy, & iterator, unroll_index, j, & internal_vars) ;
294+ unrolled. extend ( body_copy) ;
295+ }
296+
297+ let num_inserted = unrolled. len ( ) ;
298+ lines. splice ( i..=i, unrolled) ;
299+ changed = true ;
300+ i += num_inserted;
301+ continue ;
302+ }
303+
304+ i += 1 ;
305+ }
306+ changed
307+ }
308+
200309/// Analyzes a simplified function to verify that it returns on each code path.
201310fn check_function_always_returns ( func : & SimpleFunction ) {
202311 check_block_always_returns ( & func. name , & func. instructions ) ;
@@ -463,7 +572,6 @@ fn check_condition_scoping(condition: &Condition, ctx: &Context) {
463572struct Counters {
464573 aux_vars : usize ,
465574 loops : usize ,
466- unrolls : usize ,
467575}
468576
469577#[ derive( Debug , Clone , Default ) ]
@@ -754,37 +862,7 @@ fn simplify_lines(
754862 unroll,
755863 line_number,
756864 } => {
757- if * unroll {
758- let ( internal_variables, _) = find_variable_usage ( body, const_arrays) ;
759- let mut unrolled_lines = Vec :: new ( ) ;
760- let start_evaluated = start. naive_eval ( const_arrays) . unwrap ( ) . to_usize ( ) ;
761- let end_evaluated = end. naive_eval ( const_arrays) . unwrap ( ) . to_usize ( ) ;
762- let unroll_index = counters. unrolls ;
763- counters. unrolls += 1 ;
764-
765- let mut range = ( start_evaluated..end_evaluated) . collect :: < Vec < _ > > ( ) ;
766- if * rev {
767- range. reverse ( ) ;
768- }
769-
770- for i in range {
771- let mut body_copy = body. clone ( ) ;
772- replace_vars_for_unroll ( & mut body_copy, iterator, unroll_index, i, & internal_variables) ;
773- unrolled_lines. extend ( simplify_lines (
774- functions,
775- 0 ,
776- & body_copy,
777- counters,
778- new_functions,
779- in_a_loop,
780- array_manager,
781- const_malloc,
782- const_arrays,
783- ) ) ;
784- }
785- res. extend ( unrolled_lines) ;
786- continue ;
787- }
865+ assert ! ( !* unroll, "Unrolled loops should have been handled already" ) ;
788866
789867 if * rev {
790868 unimplemented ! ( "Reverse for non-unrolled loops are not implemented yet" ) ;
@@ -2034,7 +2112,8 @@ fn handle_inlined_functions_helper(
20342112 }
20352113}
20362114
2037- fn handle_const_arguments ( program : & mut Program ) {
2115+ fn handle_const_arguments ( program : & mut Program ) -> bool {
2116+ let mut any_changes = false ;
20382117 let mut new_functions = BTreeMap :: < String , Function > :: new ( ) ;
20392118 let constant_functions = program
20402119 . functions
@@ -2046,7 +2125,7 @@ fn handle_const_arguments(program: &mut Program) {
20462125 // First pass: process non-const functions that call const functions
20472126 for func in program. functions . values_mut ( ) {
20482127 if !func. has_const_arguments ( ) {
2049- handle_const_arguments_helper (
2128+ any_changes |= handle_const_arguments_helper (
20502129 & mut func. body ,
20512130 & constant_functions,
20522131 & mut new_functions,
@@ -2055,7 +2134,7 @@ fn handle_const_arguments(program: &mut Program) {
20552134 }
20562135 }
20572136
2058- // Process newly created const functions recursively until no more changes
2137+ // Process newly created functions recursively until no more changes
20592138 let mut changed = true ;
20602139 let mut const_depth = 0 ;
20612140 while changed {
@@ -2078,6 +2157,7 @@ fn handle_const_arguments(program: &mut Program) {
20782157 ) ;
20792158 if additional_functions. len ( ) > initial_count {
20802159 changed = true ;
2160+ any_changes = true ;
20812161 }
20822162 }
20832163 }
@@ -2087,26 +2167,32 @@ fn handle_const_arguments(program: &mut Program) {
20872167 if let std:: collections:: btree_map:: Entry :: Vacant ( e) = new_functions. entry ( name) {
20882168 e. insert ( func) ;
20892169 changed = true ;
2170+ any_changes = true ;
20902171 }
20912172 }
20922173 }
20932174
2175+ any_changes |= !new_functions. is_empty ( ) ;
2176+
20942177 for ( name, func) in new_functions {
2095- assert ! ( !program. functions. contains_key( & name) , ) ;
2178+ assert ! ( !program. functions. contains_key( & name) ) ;
20962179 program. functions . insert ( name, func) ;
20972180 }
2098- for const_func in constant_functions. keys ( ) {
2099- program. functions . remove ( const_func) ;
2100- }
2181+
2182+ // DON'T remove const functions here - they might be needed in subsequent iterations
2183+ // They will be removed at the end of simplify_program
2184+
2185+ any_changes
21012186}
21022187
21032188fn handle_const_arguments_helper (
21042189 lines : & mut [ Line ] ,
21052190 constant_functions : & BTreeMap < String , Function > ,
21062191 new_functions : & mut BTreeMap < String , Function > ,
21072192 const_arrays : & BTreeMap < String , Vec < usize > > ,
2108- ) {
2109- for line in lines {
2193+ ) -> bool {
2194+ let mut changed = false ;
2195+ ' outer: for line in lines {
21102196 match line {
21112197 Line :: FunctionCall {
21122198 function_name,
@@ -2115,16 +2201,19 @@ fn handle_const_arguments_helper(
21152201 line_number : _,
21162202 } => {
21172203 if let Some ( func) = constant_functions. get ( function_name) {
2118- // If the function has constant arguments, we need to handle them
2204+ // Check if all const arguments can be evaluated
21192205 let mut const_evals = Vec :: new ( ) ;
21202206 for ( arg_expr, ( arg_var, is_constant) ) in args. iter ( ) . zip ( & func. arguments ) {
21212207 if * is_constant {
2122- let const_eval = arg_expr
2123- . naive_eval ( const_arrays)
2124- . unwrap_or_else ( || panic ! ( "Failed to evaluate constant argument: {arg_expr}" ) ) ;
2125- const_evals. push ( ( arg_var. clone ( ) , const_eval) ) ;
2208+ if let Some ( const_eval) = arg_expr. naive_eval ( const_arrays) {
2209+ const_evals. push ( ( arg_var. clone ( ) , const_eval) ) ;
2210+ } else {
2211+ // Skip this call, will be handled in a later pass after more unrolling
2212+ continue ' outer;
2213+ }
21262214 }
21272215 }
2216+
21282217 let const_funct_name = format ! (
21292218 "{function_name}_{}" ,
21302219 const_evals
@@ -2144,6 +2233,8 @@ fn handle_const_arguments_helper(
21442233 . map ( |( arg_expr, _) | arg_expr. clone ( ) )
21452234 . collect ( ) ;
21462235
2236+ changed = true ;
2237+
21472238 if new_functions. contains_key ( & const_funct_name) {
21482239 continue ;
21492240 }
@@ -2173,21 +2264,21 @@ fn handle_const_arguments_helper(
21732264 else_branch,
21742265 ..
21752266 } => {
2176- handle_const_arguments_helper ( then_branch, constant_functions, new_functions, const_arrays) ;
2177- handle_const_arguments_helper ( else_branch, constant_functions, new_functions, const_arrays) ;
2267+ changed |= handle_const_arguments_helper ( then_branch, constant_functions, new_functions, const_arrays) ;
2268+ changed |= handle_const_arguments_helper ( else_branch, constant_functions, new_functions, const_arrays) ;
21782269 }
21792270 Line :: ForLoop { body, unroll : _, .. } => {
2180- // TODO we should unroll before const arguments handling
2181- handle_const_arguments_helper ( body, constant_functions, new_functions, const_arrays) ;
2271+ changed |= handle_const_arguments_helper ( body, constant_functions, new_functions, const_arrays) ;
21822272 }
21832273 Line :: Match { arms, .. } => {
21842274 for ( _, arm) in arms {
2185- handle_const_arguments_helper ( arm, constant_functions, new_functions, const_arrays) ;
2275+ changed |= handle_const_arguments_helper ( arm, constant_functions, new_functions, const_arrays) ;
21862276 }
21872277 }
21882278 _ => { }
21892279 }
21902280 }
2281+ changed
21912282}
21922283
21932284fn replace_vars_by_const_in_expr ( expr : & mut Expression , map : & BTreeMap < Var , F > ) {
0 commit comments