Skip to content

Commit 35b06d5

Browse files
committed
zkDSL compiler: enable intertwined loop unrolling / constant arguments
1 parent d756c08 commit 35b06d5

File tree

2 files changed

+186
-52
lines changed

2 files changed

+186
-52
lines changed

crates/lean_compiler/src/a_simplify_lang.rs

Lines changed: 143 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,34 @@ pub enum SimpleLine {
154154
pub fn simplify_program(mut program: Program) -> SimpleProgram {
155155
check_program_scoping(&program);
156156
handle_inlined_functions(&mut program);
157-
handle_const_arguments(&mut program);
157+
158+
// Iterate between unrolling and const argument handling until fixed point
159+
let mut unroll_counter = Counter::new();
160+
let mut max_iterations = 100;
161+
loop {
162+
let mut any_change = false;
163+
164+
any_change |= unroll_loops_in_program(&mut program, &mut unroll_counter);
165+
any_change |= handle_const_arguments(&mut program);
166+
167+
max_iterations -= 1;
168+
assert!(max_iterations > 0, "Too many iterations while simplifying program");
169+
if !any_change {
170+
break;
171+
}
172+
}
173+
174+
// Remove all const functions - they should all have been specialized by now
175+
let const_func_names: Vec<_> = program
176+
.functions
177+
.iter()
178+
.filter(|(_, func)| func.has_const_arguments())
179+
.map(|(name, _)| name.clone())
180+
.collect();
181+
for name in const_func_names {
182+
program.functions.remove(&name);
183+
}
184+
158185
let mut new_functions = BTreeMap::new();
159186
let mut counters = Counters::default();
160187
let mut const_malloc = ConstMalloc::default();
@@ -197,6 +224,88 @@ pub fn simplify_program(mut program: Program) -> SimpleProgram {
197224
}
198225
}
199226

227+
fn unroll_loops_in_program(program: &mut Program, unroll_counter: &mut Counter) -> bool {
228+
let mut changed = false;
229+
for func in program.functions.values_mut() {
230+
changed |= unroll_loops_in_lines(&mut func.body, &program.const_arrays, unroll_counter);
231+
}
232+
changed
233+
}
234+
235+
fn unroll_loops_in_lines(
236+
lines: &mut Vec<Line>,
237+
const_arrays: &BTreeMap<String, Vec<usize>>,
238+
unroll_counter: &mut Counter,
239+
) -> bool {
240+
let mut changed = false;
241+
let mut i = 0;
242+
while i < lines.len() {
243+
// First, recursively process nested structures
244+
match &mut lines[i] {
245+
Line::ForLoop { body, .. } => {
246+
changed |= unroll_loops_in_lines(body, const_arrays, unroll_counter);
247+
}
248+
Line::IfCondition {
249+
then_branch,
250+
else_branch,
251+
..
252+
} => {
253+
changed |= unroll_loops_in_lines(then_branch, const_arrays, unroll_counter);
254+
changed |= unroll_loops_in_lines(else_branch, const_arrays, unroll_counter);
255+
}
256+
Line::Match { arms, .. } => {
257+
for (_, arm_body) in arms {
258+
changed |= unroll_loops_in_lines(arm_body, const_arrays, unroll_counter);
259+
}
260+
}
261+
_ => {}
262+
}
263+
264+
// Now try to unroll if it's an unrollable loop
265+
if let Line::ForLoop {
266+
iterator,
267+
start,
268+
end,
269+
body,
270+
rev,
271+
unroll: true,
272+
line_number: _,
273+
} = &lines[i]
274+
&& let (Some(start_val), Some(end_val)) = (start.naive_eval(const_arrays), end.naive_eval(const_arrays))
275+
{
276+
let start_usize = start_val.to_usize();
277+
let end_usize = end_val.to_usize();
278+
let unroll_index = unroll_counter.next();
279+
280+
let (internal_vars, _) = find_variable_usage(body, const_arrays);
281+
282+
let mut range: Vec<_> = (start_usize..end_usize).collect();
283+
if *rev {
284+
range.reverse();
285+
}
286+
287+
let iterator = iterator.clone();
288+
let body = body.clone();
289+
290+
let mut unrolled = Vec::new();
291+
for j in range {
292+
let mut body_copy = body.clone();
293+
replace_vars_for_unroll(&mut body_copy, &iterator, unroll_index, j, &internal_vars);
294+
unrolled.extend(body_copy);
295+
}
296+
297+
let num_inserted = unrolled.len();
298+
lines.splice(i..=i, unrolled);
299+
changed = true;
300+
i += num_inserted;
301+
continue;
302+
}
303+
304+
i += 1;
305+
}
306+
changed
307+
}
308+
200309
/// Analyzes a simplified function to verify that it returns on each code path.
201310
fn check_function_always_returns(func: &SimpleFunction) {
202311
check_block_always_returns(&func.name, &func.instructions);
@@ -463,7 +572,6 @@ fn check_condition_scoping(condition: &Condition, ctx: &Context) {
463572
struct Counters {
464573
aux_vars: usize,
465574
loops: usize,
466-
unrolls: usize,
467575
}
468576

469577
#[derive(Debug, Clone, Default)]
@@ -754,37 +862,7 @@ fn simplify_lines(
754862
unroll,
755863
line_number,
756864
} => {
757-
if *unroll {
758-
let (internal_variables, _) = find_variable_usage(body, const_arrays);
759-
let mut unrolled_lines = Vec::new();
760-
let start_evaluated = start.naive_eval(const_arrays).unwrap().to_usize();
761-
let end_evaluated = end.naive_eval(const_arrays).unwrap().to_usize();
762-
let unroll_index = counters.unrolls;
763-
counters.unrolls += 1;
764-
765-
let mut range = (start_evaluated..end_evaluated).collect::<Vec<_>>();
766-
if *rev {
767-
range.reverse();
768-
}
769-
770-
for i in range {
771-
let mut body_copy = body.clone();
772-
replace_vars_for_unroll(&mut body_copy, iterator, unroll_index, i, &internal_variables);
773-
unrolled_lines.extend(simplify_lines(
774-
functions,
775-
0,
776-
&body_copy,
777-
counters,
778-
new_functions,
779-
in_a_loop,
780-
array_manager,
781-
const_malloc,
782-
const_arrays,
783-
));
784-
}
785-
res.extend(unrolled_lines);
786-
continue;
787-
}
865+
assert!(!*unroll, "Unrolled loops should have been handled already");
788866

789867
if *rev {
790868
unimplemented!("Reverse for non-unrolled loops are not implemented yet");
@@ -2034,7 +2112,8 @@ fn handle_inlined_functions_helper(
20342112
}
20352113
}
20362114

2037-
fn handle_const_arguments(program: &mut Program) {
2115+
fn handle_const_arguments(program: &mut Program) -> bool {
2116+
let mut any_changes = false;
20382117
let mut new_functions = BTreeMap::<String, Function>::new();
20392118
let constant_functions = program
20402119
.functions
@@ -2046,7 +2125,7 @@ fn handle_const_arguments(program: &mut Program) {
20462125
// First pass: process non-const functions that call const functions
20472126
for func in program.functions.values_mut() {
20482127
if !func.has_const_arguments() {
2049-
handle_const_arguments_helper(
2128+
any_changes |= handle_const_arguments_helper(
20502129
&mut func.body,
20512130
&constant_functions,
20522131
&mut new_functions,
@@ -2055,7 +2134,7 @@ fn handle_const_arguments(program: &mut Program) {
20552134
}
20562135
}
20572136

2058-
// Process newly created const functions recursively until no more changes
2137+
// Process newly created functions recursively until no more changes
20592138
let mut changed = true;
20602139
let mut const_depth = 0;
20612140
while changed {
@@ -2078,6 +2157,7 @@ fn handle_const_arguments(program: &mut Program) {
20782157
);
20792158
if additional_functions.len() > initial_count {
20802159
changed = true;
2160+
any_changes = true;
20812161
}
20822162
}
20832163
}
@@ -2087,26 +2167,32 @@ fn handle_const_arguments(program: &mut Program) {
20872167
if let std::collections::btree_map::Entry::Vacant(e) = new_functions.entry(name) {
20882168
e.insert(func);
20892169
changed = true;
2170+
any_changes = true;
20902171
}
20912172
}
20922173
}
20932174

2175+
any_changes |= !new_functions.is_empty();
2176+
20942177
for (name, func) in new_functions {
2095-
assert!(!program.functions.contains_key(&name),);
2178+
assert!(!program.functions.contains_key(&name));
20962179
program.functions.insert(name, func);
20972180
}
2098-
for const_func in constant_functions.keys() {
2099-
program.functions.remove(const_func);
2100-
}
2181+
2182+
// DON'T remove const functions here - they might be needed in subsequent iterations
2183+
// They will be removed at the end of simplify_program
2184+
2185+
any_changes
21012186
}
21022187

21032188
fn handle_const_arguments_helper(
21042189
lines: &mut [Line],
21052190
constant_functions: &BTreeMap<String, Function>,
21062191
new_functions: &mut BTreeMap<String, Function>,
21072192
const_arrays: &BTreeMap<String, Vec<usize>>,
2108-
) {
2109-
for line in lines {
2193+
) -> bool {
2194+
let mut changed = false;
2195+
'outer: for line in lines {
21102196
match line {
21112197
Line::FunctionCall {
21122198
function_name,
@@ -2115,16 +2201,19 @@ fn handle_const_arguments_helper(
21152201
line_number: _,
21162202
} => {
21172203
if let Some(func) = constant_functions.get(function_name) {
2118-
// If the function has constant arguments, we need to handle them
2204+
// Check if all const arguments can be evaluated
21192205
let mut const_evals = Vec::new();
21202206
for (arg_expr, (arg_var, is_constant)) in args.iter().zip(&func.arguments) {
21212207
if *is_constant {
2122-
let const_eval = arg_expr
2123-
.naive_eval(const_arrays)
2124-
.unwrap_or_else(|| panic!("Failed to evaluate constant argument: {arg_expr}"));
2125-
const_evals.push((arg_var.clone(), const_eval));
2208+
if let Some(const_eval) = arg_expr.naive_eval(const_arrays) {
2209+
const_evals.push((arg_var.clone(), const_eval));
2210+
} else {
2211+
// Skip this call, will be handled in a later pass after more unrolling
2212+
continue 'outer;
2213+
}
21262214
}
21272215
}
2216+
21282217
let const_funct_name = format!(
21292218
"{function_name}_{}",
21302219
const_evals
@@ -2144,6 +2233,8 @@ fn handle_const_arguments_helper(
21442233
.map(|(arg_expr, _)| arg_expr.clone())
21452234
.collect();
21462235

2236+
changed = true;
2237+
21472238
if new_functions.contains_key(&const_funct_name) {
21482239
continue;
21492240
}
@@ -2173,21 +2264,21 @@ fn handle_const_arguments_helper(
21732264
else_branch,
21742265
..
21752266
} => {
2176-
handle_const_arguments_helper(then_branch, constant_functions, new_functions, const_arrays);
2177-
handle_const_arguments_helper(else_branch, constant_functions, new_functions, const_arrays);
2267+
changed |= handle_const_arguments_helper(then_branch, constant_functions, new_functions, const_arrays);
2268+
changed |= handle_const_arguments_helper(else_branch, constant_functions, new_functions, const_arrays);
21782269
}
21792270
Line::ForLoop { body, unroll: _, .. } => {
2180-
// TODO we should unroll before const arguments handling
2181-
handle_const_arguments_helper(body, constant_functions, new_functions, const_arrays);
2271+
changed |= handle_const_arguments_helper(body, constant_functions, new_functions, const_arrays);
21822272
}
21832273
Line::Match { arms, .. } => {
21842274
for (_, arm) in arms {
2185-
handle_const_arguments_helper(arm, constant_functions, new_functions, const_arrays);
2275+
changed |= handle_const_arguments_helper(arm, constant_functions, new_functions, const_arrays);
21862276
}
21872277
}
21882278
_ => {}
21892279
}
21902280
}
2281+
changed
21912282
}
21922283

21932284
fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap<Var, F>) {

crates/lean_compiler/tests/test_compiler.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,3 +838,46 @@ fn test_array_return_targets_with_expressions() {
838838
"#;
839839
compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false);
840840
}
841+
842+
#[test]
843+
fn intertwined_unrolled_loops_and_const_function_arguments() {
844+
let program = r#"
845+
const ARR = [10, 100];
846+
fn main() {
847+
buff = malloc(3);
848+
buff[0] = 0;
849+
for i in 0..2 unroll {
850+
res = f1(ARR[i]);
851+
buff[i + 1] = res;
852+
}
853+
assert buff[2] == 1390320454;
854+
return;
855+
}
856+
857+
fn f1(const x) -> 1 {
858+
buff = malloc(9);
859+
buff[0] = 1;
860+
for i in x..x+4 unroll {
861+
for j in i..i+2 unroll {
862+
index = (i - x) * 2 + (j - i);
863+
res = f2(i, j);
864+
buff[index+1] = buff[index] * res;
865+
}
866+
}
867+
return buff[8];
868+
}
869+
870+
fn f2(const x, const y) -> 1 {
871+
buff = malloc(7);
872+
buff[0] = 0;
873+
for i in x..x+2 unroll {
874+
for j in i..i+3 unroll {
875+
index = (i - x) * 3 + (j - i);
876+
buff[index+1] = buff[index] + i + j;
877+
}
878+
}
879+
return buff[4];
880+
}
881+
"#;
882+
compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false);
883+
}

0 commit comments

Comments
 (0)