Skip to content

Commit bb3806e

Browse files
committed
zkDSL new feature: multi-dimensional constant arrays
1 parent 0d62e33 commit bb3806e

File tree

8 files changed

+499
-138
lines changed

8 files changed

+499
-138
lines changed

crates/lean_compiler/src/a_simplify_lang.rs

Lines changed: 92 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::{
55
AssignmentTarget, AssumeBoolean, Condition, ConstExpression, ConstMallocLabel, ConstantValue, Context,
66
Expression, Function, Line, Program, Scope, SimpleExpr, Var,
77
},
8+
parser::ConstArrayValue,
89
};
910
use lean_vm::{Boolean, BooleanExpr, CustomHint, FileId, SourceLineNumber, SourceLocation, Table, TableT};
1011
use std::{
@@ -230,7 +231,7 @@ fn unroll_loops_in_program(program: &mut Program, unroll_counter: &mut Counter)
230231

231232
fn unroll_loops_in_lines(
232233
lines: &mut Vec<Line>,
233-
const_arrays: &BTreeMap<String, Vec<usize>>,
234+
const_arrays: &BTreeMap<String, ConstArrayValue>,
234235
unroll_counter: &mut Counter,
235236
) -> bool {
236237
let mut changed = false;
@@ -506,7 +507,9 @@ fn check_expr_scoping(expr: &Expression, ctx: &Context) {
506507
}
507508
Expression::ArrayAccess { array, index } => {
508509
check_simple_expr_scoping(array, ctx);
509-
check_expr_scoping(index, ctx);
510+
for idx in index {
511+
check_expr_scoping(idx, ctx);
512+
}
510513
}
511514
Expression::Binary {
512515
left,
@@ -593,7 +596,7 @@ fn simplify_lines(
593596
in_a_loop: bool,
594597
array_manager: &mut ArrayManager,
595598
const_malloc: &mut ConstMalloc,
596-
const_arrays: &BTreeMap<String, Vec<usize>>,
599+
const_arrays: &BTreeMap<String, ConstArrayValue>,
597600
) -> Vec<SimpleLine> {
598601
let mut res = Vec::new();
599602
for line in lines {
@@ -662,7 +665,7 @@ fn simplify_lines(
662665
counters,
663666
&mut res,
664667
array.clone(),
665-
index,
668+
std::slice::from_ref(index),
666669
ArrayAccessType::ArrayIsAssigned(value.clone()),
667670
array_manager,
668671
const_malloc,
@@ -724,6 +727,24 @@ fn simplify_lines(
724727
} else if let Ok(right) = right.clone().try_into() {
725728
(right, left)
726729
} else {
730+
// Both are constants - evaluate at compile time
731+
if let (SimpleExpr::Constant(left_const), SimpleExpr::Constant(right_const)) =
732+
(&left, &right)
733+
&& let (Some(left_val), Some(right_val)) =
734+
(left_const.naive_eval(), right_const.naive_eval())
735+
{
736+
if left_val == right_val {
737+
// Assertion passes at compile time, no code needed
738+
continue;
739+
} else {
740+
panic!(
741+
"Compile-time assertion failed: {} != {} (lines {})",
742+
left_val.to_usize(),
743+
right_val.to_usize(),
744+
line_number
745+
);
746+
}
747+
}
727748
panic!("Unsupported equality assertion: {left:?}, {right:?}")
728749
};
729750
res.push(SimpleLine::Assignment {
@@ -1003,7 +1024,7 @@ fn simplify_lines(
10031024
counters,
10041025
&mut res,
10051026
array,
1006-
&index,
1027+
&[*index],
10071028
ArrayAccessType::ArrayIsAssigned(Expression::Value(SimpleExpr::Var(temp_vars[i].clone()))),
10081029
array_manager,
10091030
const_malloc,
@@ -1120,7 +1141,7 @@ fn simplify_expr(
11201141
counters: &mut Counters,
11211142
array_manager: &mut ArrayManager,
11221143
const_malloc: &ConstMalloc,
1123-
const_arrays: &BTreeMap<String, Vec<usize>>,
1144+
const_arrays: &BTreeMap<String, ConstArrayValue>,
11241145
) -> SimpleExpr {
11251146
match expr {
11261147
Expression::Value(value) => value.simplify_if_const(),
@@ -1129,28 +1150,32 @@ fn simplify_expr(
11291150
if let SimpleExpr::Var(array_var) = array
11301151
&& let Some(arr) = const_arrays.get(array_var)
11311152
{
1132-
let simplified_index = simplify_expr(index, lines, counters, array_manager, const_malloc, const_arrays);
1133-
if let SimpleExpr::Constant(c) = &simplified_index
1134-
&& let Some(idx_val) = c.naive_eval()
1135-
{
1136-
let idx = idx_val.to_usize();
1137-
if idx < arr.len() {
1138-
return SimpleExpr::Constant(ConstExpression::from(arr[idx]));
1139-
} else {
1140-
panic!(
1141-
"Const array '{}' index {} out of bounds (length {})",
1142-
array_var,
1143-
idx,
1144-
arr.len()
1145-
);
1146-
}
1147-
}
1148-
panic!("Const array '{array_var}' can only be accessed with compile-time constant indices",);
1153+
let simplified_index = index
1154+
.iter()
1155+
.map(|idx| {
1156+
simplify_expr(idx, lines, counters, array_manager, const_malloc, const_arrays)
1157+
.as_constant()
1158+
.expect("Const array access index should be constant")
1159+
.naive_eval()
1160+
.expect("Const array access index should be constant")
1161+
.to_usize()
1162+
})
1163+
.collect::<Vec<_>>();
1164+
1165+
return SimpleExpr::Constant(ConstExpression::from(
1166+
arr.navigate(&simplified_index)
1167+
.expect("Const array access index out of bounds")
1168+
.as_scalar()
1169+
.expect("Const array access should return a scalar"),
1170+
));
11491171
}
11501172

1173+
assert_eq!(index.len(), 1);
1174+
let index = index[0].clone();
1175+
11511176
if let SimpleExpr::Var(array_var) = array
11521177
&& let Some(label) = const_malloc.map.get(array_var)
1153-
&& let Ok(mut offset) = ConstExpression::try_from(*index.clone())
1178+
&& let Ok(mut offset) = ConstExpression::try_from(index.clone())
11541179
{
11551180
offset = offset.try_naive_simplification();
11561181
return SimpleExpr::ConstMallocAccess {
@@ -1159,7 +1184,7 @@ fn simplify_expr(
11591184
};
11601185
}
11611186

1162-
let aux_arr = array_manager.get_aux_var(array, index); // auxiliary var to store m[array + index]
1187+
let aux_arr = array_manager.get_aux_var(array, &index); // auxiliary var to store m[array + index]
11631188

11641189
if !array_manager.valid.insert(aux_arr.clone()) {
11651190
return SimpleExpr::Var(aux_arr);
@@ -1169,7 +1194,7 @@ fn simplify_expr(
11691194
counters,
11701195
lines,
11711196
array.clone(),
1172-
index,
1197+
&[index],
11731198
ArrayAccessType::VarIsAssigned(aux_arr.clone()),
11741199
array_manager,
11751200
const_malloc,
@@ -1216,7 +1241,7 @@ fn simplify_expr(
12161241
/// Returns (internal_vars, external_vars)
12171242
pub fn find_variable_usage(
12181243
lines: &[Line],
1219-
const_arrays: &BTreeMap<String, Vec<usize>>,
1244+
const_arrays: &BTreeMap<String, ConstArrayValue>,
12201245
) -> (BTreeSet<Var>, BTreeSet<Var>) {
12211246
let mut internal_vars = BTreeSet::new();
12221247
let mut external_vars = BTreeSet::new();
@@ -1376,7 +1401,9 @@ fn inline_expr(expr: &mut Expression, args: &BTreeMap<Var, SimpleExpr>, inlining
13761401
}
13771402
Expression::ArrayAccess { array, index } => {
13781403
inline_simple_expr(array, args, inlining_count);
1379-
inline_expr(index, args, inlining_count);
1404+
for idx in index {
1405+
inline_expr(idx, args, inlining_count);
1406+
}
13801407
}
13811408
Expression::Binary { left, right, .. } => {
13821409
inline_expr(left, args, inlining_count);
@@ -1540,7 +1567,7 @@ fn inline_lines(
15401567
}
15411568
}
15421569

1543-
fn vars_in_expression(expr: &Expression, const_arrays: &BTreeMap<String, Vec<usize>>) -> BTreeSet<Var> {
1570+
fn vars_in_expression(expr: &Expression, const_arrays: &BTreeMap<String, ConstArrayValue>) -> BTreeSet<Var> {
15441571
let mut vars = BTreeSet::new();
15451572
match expr {
15461573
Expression::Value(value) => {
@@ -1554,7 +1581,9 @@ fn vars_in_expression(expr: &Expression, const_arrays: &BTreeMap<String, Vec<usi
15541581
{
15551582
vars.insert(array.clone());
15561583
}
1557-
vars.extend(vars_in_expression(index, const_arrays));
1584+
for idx in index {
1585+
vars.extend(vars_in_expression(idx, const_arrays));
1586+
}
15581587
}
15591588
Expression::Binary { left, right, .. } => {
15601589
vars.extend(vars_in_expression(left, const_arrays));
@@ -1580,40 +1609,46 @@ fn handle_array_assignment(
15801609
counters: &mut Counters,
15811610
res: &mut Vec<SimpleLine>,
15821611
array: SimpleExpr,
1583-
index: &Expression,
1612+
index: &[Expression],
15841613
access_type: ArrayAccessType,
15851614
array_manager: &mut ArrayManager,
15861615
const_malloc: &ConstMalloc,
1587-
const_arrays: &BTreeMap<String, Vec<usize>>,
1616+
const_arrays: &BTreeMap<String, ConstArrayValue>,
15881617
) {
1589-
let simplified_index = simplify_expr(index, res, counters, array_manager, const_malloc, const_arrays);
1618+
let simplified_index = index
1619+
.iter()
1620+
.map(|idx| simplify_expr(idx, res, counters, array_manager, const_malloc, const_arrays))
1621+
.collect::<Vec<_>>();
15901622

15911623
if let (ArrayAccessType::VarIsAssigned(var), SimpleExpr::Var(array_var)) = (&access_type, &array)
15921624
&& let Some(const_array) = const_arrays.get(array_var)
15931625
{
1594-
let index = simplified_index
1595-
.as_constant()
1596-
.expect("Const array access index should be constant")
1597-
.naive_eval()
1598-
.unwrap()
1599-
.to_usize();
1600-
assert!(
1601-
index < const_array.len(),
1602-
"Const array '{}' index {} out of bounds (length {})",
1603-
array_var,
1604-
index,
1605-
const_array.len()
1606-
);
1626+
let idx = simplified_index
1627+
.iter()
1628+
.map(|idx| {
1629+
idx.as_constant()
1630+
.expect("Const array access index should be constant")
1631+
.naive_eval()
1632+
.unwrap()
1633+
.to_usize()
1634+
})
1635+
.collect::<Vec<_>>();
1636+
let value = const_array
1637+
.navigate(&idx)
1638+
.expect("Const array access index out of bounds")
1639+
.as_scalar()
1640+
.expect("Const array access should return a scalar");
16071641
res.push(SimpleLine::Assignment {
16081642
var: var.clone().into(),
16091643
operation: HighLevelOperation::Add,
1610-
arg0: SimpleExpr::Constant(ConstExpression::from(const_array[index])),
1644+
arg0: SimpleExpr::Constant(ConstExpression::from(value)),
16111645
arg1: SimpleExpr::zero(),
16121646
});
16131647
return;
16141648
}
16151649

1616-
if let SimpleExpr::Constant(offset) = simplified_index.clone()
1650+
if simplified_index.len() == 1
1651+
&& let SimpleExpr::Constant(offset) = simplified_index[0].clone()
16171652
&& let SimpleExpr::Var(array_var) = &array
16181653
&& let Some(label) = const_malloc.map.get(array_var)
16191654
&& let ArrayAccessType::ArrayIsAssigned(Expression::Binary { left, operation, right }) = &access_type
@@ -1640,7 +1675,8 @@ fn handle_array_assignment(
16401675
};
16411676

16421677
// TODO opti: in some case we could use ConstMallocAccess
1643-
1678+
assert_eq!(simplified_index.len(), 1);
1679+
let simplified_index = simplified_index[0].clone();
16441680
let (index_var, shift) = match simplified_index {
16451681
SimpleExpr::Constant(c) => (array, c),
16461682
_ => {
@@ -1745,8 +1781,9 @@ fn replace_vars_for_unroll_in_expr(
17451781
*array_var = format!("@unrolled_{unroll_index}_{iterator_value}_{array_var}");
17461782
}
17471783
}
1748-
1749-
replace_vars_for_unroll_in_expr(index, iterator, unroll_index, iterator_value, internal_vars);
1784+
for index in index {
1785+
replace_vars_for_unroll_in_expr(index, iterator, unroll_index, iterator_value, internal_vars);
1786+
}
17501787
}
17511788
Expression::Binary { left, right, .. } => {
17521789
replace_vars_for_unroll_in_expr(left, iterator, unroll_index, iterator_value, internal_vars);
@@ -2163,7 +2200,7 @@ fn handle_const_arguments_helper(
21632200
lines: &mut [Line],
21642201
constant_functions: &BTreeMap<String, Function>,
21652202
new_functions: &mut BTreeMap<String, Function>,
2166-
const_arrays: &BTreeMap<String, Vec<usize>>,
2203+
const_arrays: &BTreeMap<String, ConstArrayValue>,
21672204
) -> bool {
21682205
let mut changed = false;
21692206
'outer: for line in lines {
@@ -2287,7 +2324,9 @@ fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap<Var, F>)
22872324
if let SimpleExpr::Var(array_var) = array {
22882325
assert!(!map.contains_key(array_var), "Array {array_var} is a constant");
22892326
}
2290-
replace_vars_by_const_in_expr(index, map);
2327+
for index in index {
2328+
replace_vars_by_const_in_expr(index, map);
2329+
}
22912330
}
22922331
Expression::Binary { left, right, .. } => {
22932332
replace_vars_by_const_in_expr(left, map);

crates/lean_compiler/src/grammar.pest

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ import_statement = { "import" ~ filepath ~ ";" }
88

99
// Constants
1010
constant_declaration = { "const" ~ identifier ~ "=" ~ (array_literal | expression) ~ ";" }
11-
array_literal = { "[" ~ (expression ~ ("," ~ expression)*)? ~ "]" }
11+
array_literal = { "[" ~ (array_element ~ ("," ~ array_element)*)? ~ "]" }
12+
array_element = { array_literal | expression }
1213

1314
// Functions
1415
function = { pragma? ~ "fn" ~ identifier ~ "(" ~ parameter_list? ~ ")" ~ inlined_statement? ~ return_count? ~ "{" ~ statement* ~ "}" }
@@ -102,8 +103,9 @@ primary = {
102103
log2_ceil_expr = { "log2_ceil" ~ "(" ~ expression ~ ")" }
103104
next_multiple_of_expr = { "next_multiple_of" ~ "(" ~ expression ~ "," ~ expression ~ ")" }
104105
saturating_sub_expr = { "saturating_sub" ~ "(" ~ expression ~ "," ~ expression ~ ")" }
105-
len_expr = { "len" ~ "(" ~ identifier ~ ")" }
106-
array_access_expr = { identifier ~ "[" ~ expression ~ "]" }
106+
len_expr = { "len" ~ "(" ~ len_argument ~ ")" }
107+
len_argument = { identifier ~ ("[" ~ expression ~ "]")* }
108+
array_access_expr = { identifier ~ ("[" ~ expression ~ "]")+ }
107109

108110
// Basic elements
109111
var_or_constant = { constant_value | identifier }

0 commit comments

Comments
 (0)