Skip to content

Commit d756c08

Browse files
committed
zkDSL new feature: assign the result of a function call directly to a given index in a memory slice, without declaring intermediary variables (my_vector[3] = my_func(45, 34);)
1 parent cb86e70 commit d756c08

File tree

5 files changed

+278
-49
lines changed

5 files changed

+278
-49
lines changed

crates/lean_compiler/src/a_simplify_lang.rs

Lines changed: 135 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ use crate::{
22
Counter, F,
33
ir::HighLevelOperation,
44
lang::{
5-
AssumeBoolean, Condition, ConstExpression, ConstMallocLabel, ConstantValue, Context, Expression, Function,
6-
Line, Program, Scope, SimpleExpr, Var,
5+
AssignmentTarget, AssumeBoolean, Condition, ConstExpression, ConstMallocLabel, ConstantValue, Context,
6+
Expression, Function, Line, Program, Scope, SimpleExpr, Var,
77
},
88
};
99
use lean_vm::{Boolean, BooleanExpr, SourceLineNumber, Table, TableT};
@@ -323,12 +323,26 @@ fn check_block_scoping(block: &[Line], ctx: &mut Context) {
323323
check_expr_scoping(arg, ctx);
324324
}
325325
let last_scope = ctx.scopes.last_mut().unwrap();
326-
for var in return_data {
327-
assert!(
328-
!last_scope.vars.contains(var),
329-
"Variable declared multiple times in the same scope: {var}",
330-
);
331-
last_scope.vars.insert(var.clone());
326+
for target in return_data {
327+
match target {
328+
AssignmentTarget::Var(var) => {
329+
assert!(
330+
!last_scope.vars.contains(var),
331+
"Variable declared multiple times in the same scope: {var}",
332+
);
333+
last_scope.vars.insert(var.clone());
334+
}
335+
AssignmentTarget::ArrayAccess { .. } => {}
336+
}
337+
}
338+
for target in return_data {
339+
match target {
340+
AssignmentTarget::Var(_) => {}
341+
AssignmentTarget::ArrayAccess { array, index } => {
342+
check_simple_expr_scoping(array, ctx);
343+
check_expr_scoping(index, ctx);
344+
}
345+
}
332346
}
333347
}
334348
Line::FunctionRet { return_data } => {
@@ -888,12 +902,45 @@ fn simplify_lines(
888902
.iter()
889903
.map(|arg| simplify_expr(arg, &mut res, counters, array_manager, const_malloc, const_arrays))
890904
.collect::<Vec<_>>();
905+
906+
// Generate temp vars for all return values and track array targets
907+
let mut temp_vars = Vec::new();
908+
let mut array_targets: Vec<(usize, SimpleExpr, Box<Expression>)> = Vec::new();
909+
910+
for (i, target) in return_data.iter().enumerate() {
911+
match target {
912+
AssignmentTarget::Var(var) => {
913+
temp_vars.push(var.clone());
914+
}
915+
AssignmentTarget::ArrayAccess { array, index } => {
916+
let temp_var = format!("@ret_temp_{}", counters.aux_vars);
917+
counters.aux_vars += 1;
918+
temp_vars.push(temp_var);
919+
array_targets.push((i, array.clone(), index.clone()));
920+
}
921+
}
922+
}
923+
891924
res.push(SimpleLine::FunctionCall {
892925
function_name: function_name.clone(),
893926
args: simplified_args,
894-
return_data: return_data.clone(),
927+
return_data: temp_vars.clone(),
895928
line_number: *line_number,
896929
});
930+
931+
// For array access targets, add DEREF instructions to copy temp to array element
932+
for (i, array, index) in array_targets {
933+
handle_array_assignment(
934+
counters,
935+
&mut res,
936+
array,
937+
&index,
938+
ArrayAccessType::ArrayIsAssigned(Expression::Value(SimpleExpr::Var(temp_vars[i].clone()))),
939+
array_manager,
940+
const_malloc,
941+
const_arrays,
942+
);
943+
}
897944
}
898945
Line::FunctionRet { return_data } => {
899946
assert!(!in_a_loop, "Function return inside a loop is not currently supported");
@@ -1185,7 +1232,22 @@ pub fn find_variable_usage(
11851232
for arg in args {
11861233
on_new_expr(arg, &internal_vars, &mut external_vars);
11871234
}
1188-
internal_vars.extend(return_data.iter().cloned());
1235+
for target in return_data {
1236+
match target {
1237+
AssignmentTarget::Var(var) => {
1238+
internal_vars.insert(var.clone());
1239+
}
1240+
AssignmentTarget::ArrayAccess { array, index } => {
1241+
if let SimpleExpr::Var(var) = array {
1242+
assert!(!const_arrays.contains_key(var), "Cannot assign to const array");
1243+
if !internal_vars.contains(var) {
1244+
external_vars.insert(var.clone());
1245+
}
1246+
}
1247+
on_new_expr(index, &internal_vars, &mut external_vars);
1248+
}
1249+
}
1250+
}
11891251
}
11901252
Line::Assert { boolean, .. } => {
11911253
on_new_condition(
@@ -1288,7 +1350,12 @@ fn inline_expr(expr: &mut Expression, args: &BTreeMap<Var, SimpleExpr>, inlining
12881350
}
12891351
}
12901352

1291-
pub fn inline_lines(lines: &mut Vec<Line>, args: &BTreeMap<Var, SimpleExpr>, res: &[Var], inlining_count: usize) {
1353+
fn inline_lines(
1354+
lines: &mut Vec<Line>,
1355+
args: &BTreeMap<Var, SimpleExpr>,
1356+
res: &[AssignmentTarget],
1357+
inlining_count: usize,
1358+
) {
12921359
let inline_comparison = |comparison: &mut BooleanExpr<Expression>| {
12931360
inline_expr(&mut comparison.left, args, inlining_count);
12941361
inline_expr(&mut comparison.right, args, inlining_count);
@@ -1342,8 +1409,16 @@ pub fn inline_lines(lines: &mut Vec<Line>, args: &BTreeMap<Var, SimpleExpr>, res
13421409
for arg in func_args {
13431410
inline_expr(arg, args, inlining_count);
13441411
}
1345-
for return_var in return_data {
1346-
inline_internal_var(return_var);
1412+
for target in return_data {
1413+
match target {
1414+
AssignmentTarget::Var(var) => {
1415+
inline_internal_var(var);
1416+
}
1417+
AssignmentTarget::ArrayAccess { array, index } => {
1418+
inline_simple_expr(array, args, inlining_count);
1419+
inline_expr(index, args, inlining_count);
1420+
}
1421+
}
13471422
}
13481423
}
13491424
Line::Assert { boolean, .. } => {
@@ -1359,9 +1434,16 @@ pub fn inline_lines(lines: &mut Vec<Line>, args: &BTreeMap<Var, SimpleExpr>, res
13591434
i,
13601435
res.iter()
13611436
.zip(return_data)
1362-
.map(|(res_var, expr)| Line::Assignment {
1363-
var: res_var.clone(),
1364-
value: expr.clone(),
1437+
.map(|(target, expr)| match target {
1438+
AssignmentTarget::Var(res_var) => Line::Assignment {
1439+
var: res_var.clone(),
1440+
value: expr.clone(),
1441+
},
1442+
AssignmentTarget::ArrayAccess { array, index } => Line::ArrayAssign {
1443+
array: array.clone(),
1444+
index: (**index).clone(),
1445+
value: expr.clone(),
1446+
},
13651447
})
13661448
.collect::<Vec<_>>(),
13671449
));
@@ -1728,12 +1810,29 @@ fn replace_vars_for_unroll(
17281810
return_data,
17291811
line_number: _,
17301812
} => {
1731-
// Function calls are not unrolled, so we don't need to change them
17321813
for arg in args {
17331814
replace_vars_for_unroll_in_expr(arg, iterator, unroll_index, iterator_value, internal_vars);
17341815
}
1735-
for ret in return_data {
1736-
*ret = format!("@unrolled_{unroll_index}_{iterator_value}_{ret}");
1816+
for target in return_data {
1817+
match target {
1818+
AssignmentTarget::Var(ret) => {
1819+
*ret = format!("@unrolled_{unroll_index}_{iterator_value}_{ret}");
1820+
}
1821+
AssignmentTarget::ArrayAccess { array, index } => {
1822+
if let SimpleExpr::Var(array_var) = array
1823+
&& internal_vars.contains(array_var)
1824+
{
1825+
*array_var = format!("@unrolled_{unroll_index}_{iterator_value}_{array_var}");
1826+
}
1827+
replace_vars_for_unroll_in_expr(
1828+
index,
1829+
iterator,
1830+
unroll_index,
1831+
iterator_value,
1832+
internal_vars,
1833+
);
1834+
}
1835+
}
17371836
}
17381837
}
17391838
Line::FunctionRet { return_data } => {
@@ -1869,8 +1968,11 @@ fn handle_inlined_functions_helper(
18691968
if let Some(func) = inlined_functions.get(&*function_name) {
18701969
let mut inlined_lines = vec![];
18711970

1872-
for var in return_data.iter() {
1873-
inlined_lines.push(Line::ForwardDeclaration { var: var.clone() });
1971+
// Only add forward declarations for variable targets, not array accesses
1972+
for target in return_data.iter() {
1973+
if let AssignmentTarget::Var(var) = target {
1974+
inlined_lines.push(Line::ForwardDeclaration { var: var.clone() });
1975+
}
18741976
}
18751977

18761978
let mut simplified_args = vec![];
@@ -2148,8 +2250,18 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap<Var, F>) {
21482250
for arg in args {
21492251
replace_vars_by_const_in_expr(arg, map);
21502252
}
2151-
for ret in return_data {
2152-
assert!(!map.contains_key(ret), "Return variable {ret} is a constant");
2253+
for target in return_data {
2254+
match target {
2255+
AssignmentTarget::Var(ret) => {
2256+
assert!(!map.contains_key(ret), "Return variable {ret} is a constant");
2257+
}
2258+
AssignmentTarget::ArrayAccess { array, index } => {
2259+
if let SimpleExpr::Var(array_var) = array {
2260+
assert!(!map.contains_key(array_var), "Array {array_var} is a constant");
2261+
}
2262+
replace_vars_by_const_in_expr(index, map);
2263+
}
2264+
}
21532265
}
21542266
}
21552267
Line::IfCondition {

crates/lean_compiler/src/grammar.pest

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@ match_arm = { pattern ~ "=>" ~ "{" ~ statement* ~ "}" }
6565
pattern = { constant_value }
6666

6767
function_call = { function_res? ~ identifier ~ "(" ~ tuple_expression? ~ ")" ~ ";" }
68-
function_res = { var_list ~ "=" }
69-
var_list = { identifier ~ ("," ~ identifier)* }
68+
function_res = { return_target_list ~ "=" }
69+
return_target_list = { return_target ~ ("," ~ return_target)* }
70+
return_target = { array_access_expr | identifier }
7071

7172
assert_eq_statement = { "assert" ~ add_expr ~ "==" ~ add_expr ~ ";" }
7273
assert_not_eq_statement = { "assert" ~ add_expr ~ "!=" ~ add_expr ~ ";" }

crates/lean_compiler/src/lang.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,21 @@ impl Expression {
339339
}
340340
}
341341

342+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
343+
pub enum AssignmentTarget {
344+
Var(Var),
345+
ArrayAccess { array: SimpleExpr, index: Box<Expression> },
346+
}
347+
348+
impl Display for AssignmentTarget {
349+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
350+
match self {
351+
Self::Var(var) => write!(f, "{}", var),
352+
Self::ArrayAccess { array, index } => write!(f, "{}[{}]", array, index),
353+
}
354+
}
355+
}
356+
342357
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
343358
pub enum Line {
344359
Match {
@@ -381,7 +396,7 @@ pub enum Line {
381396
FunctionCall {
382397
function_name: String,
383398
args: Vec<Expression>,
384-
return_data: Vec<Var>,
399+
return_data: Vec<AssignmentTarget>, // Changed from Vec<Var>
385400
line_number: SourceLineNumber,
386401
},
387402
FunctionRet {
@@ -570,7 +585,7 @@ impl Line {
570585
let args_str = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>().join(", ");
571586
let return_data_str = return_data
572587
.iter()
573-
.map(|var| var.to_string())
588+
.map(|target| target.to_string())
574589
.collect::<Vec<_>>()
575590
.join(", ");
576591

0 commit comments

Comments
 (0)