Skip to content

Commit 1d6248c

Browse files
committed
zkDSL: better support for len() on multi-dimensional array
1 parent d40b5c4 commit 1d6248c

File tree

4 files changed

+121
-33
lines changed

4 files changed

+121
-33
lines changed

crates/lean_compiler/src/a_simplify_lang.rs

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,11 @@ fn check_expr_scoping(expr: &Expression, ctx: &Context) {
529529
check_expr_scoping(arg, ctx);
530530
}
531531
}
532+
Expression::Len { indices, .. } => {
533+
for idx in indices {
534+
check_expr_scoping(idx, ctx);
535+
}
536+
}
532537
}
533538
}
534539

@@ -704,7 +709,7 @@ fn simplify_lines(
704709
});
705710
}
706711
}
707-
Expression::MathExpr(_, _) => unreachable!(),
712+
Expression::MathExpr(_, _) | Expression::Len { .. } => unreachable!(),
708713
Expression::FunctionCall { .. } => {
709714
let result = simplify_expr(
710715
value,
@@ -1467,6 +1472,7 @@ fn simplify_expr(
14671472

14681473
SimpleExpr::Var(result_var)
14691474
}
1475+
Expression::Len { .. } => unreachable!(),
14701476
}
14711477
}
14721478

@@ -1651,6 +1657,11 @@ fn inline_expr(expr: &mut Expression, args: &BTreeMap<Var, SimpleExpr>, inlining
16511657
inline_expr(arg, args, inlining_count);
16521658
}
16531659
}
1660+
Expression::Len { indices, .. } => {
1661+
for idx in indices {
1662+
inline_expr(idx, args, inlining_count);
1663+
}
1664+
}
16541665
}
16551666
}
16561667

@@ -1836,6 +1847,11 @@ fn vars_in_expression(expr: &Expression, const_arrays: &BTreeMap<String, ConstAr
18361847
vars.extend(vars_in_expression(arg, const_arrays));
18371848
}
18381849
}
1850+
Expression::Len { indices, .. } => {
1851+
for idx in indices {
1852+
vars.extend(vars_in_expression(idx, const_arrays));
1853+
}
1854+
}
18391855
}
18401856
vars
18411857
}
@@ -2064,6 +2080,11 @@ fn replace_vars_for_unroll_in_expr(
20642080
replace_vars_for_unroll_in_expr(arg, iterator, unroll_index, iterator_value, internal_vars);
20652081
}
20662082
}
2083+
Expression::Len { indices, .. } => {
2084+
for idx in indices {
2085+
replace_vars_for_unroll_in_expr(idx, iterator, unroll_index, iterator_value, internal_vars);
2086+
}
2087+
}
20672088
}
20682089
}
20692090

@@ -2369,6 +2390,15 @@ fn extract_inlined_calls_from_expr(
23692390
*expr = Expression::Value(SimpleExpr::Var(aux_var));
23702391
}
23712392
}
2393+
Expression::Len { indices, .. } => {
2394+
for idx in indices.iter_mut() {
2395+
lines.extend(extract_inlined_calls_from_expr(
2396+
idx,
2397+
inlined_functions,
2398+
inlined_var_counter,
2399+
));
2400+
}
2401+
}
23722402
}
23732403

23742404
lines
@@ -2849,6 +2879,11 @@ fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap<Var, F>)
28492879
replace_vars_by_const_in_expr(arg, map);
28502880
}
28512881
}
2882+
Expression::Len { indices, .. } => {
2883+
for idx in indices {
2884+
replace_vars_by_const_in_expr(idx, map);
2885+
}
2886+
}
28522887
}
28532888
}
28542889

crates/lean_compiler/src/lang.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ impl TryFrom<Expression> for ConstExpression {
154154
Ok(Self::MathExpr(math_expr, const_args))
155155
}
156156
Expression::FunctionCall { .. } => Err(()),
157+
Expression::Len { .. } => Err(()),
157158
}
158159
}
159160
}
@@ -260,6 +261,10 @@ pub enum Expression {
260261
function_name: String,
261262
args: Vec<Self>,
262263
},
264+
Len {
265+
array: String,
266+
indices: Vec<Self>,
267+
},
263268
}
264269

265270
/// For arbitrary compile-time computations
@@ -326,6 +331,17 @@ impl From<Var> for Expression {
326331

327332
impl Expression {
328333
pub fn naive_eval(&self, const_arrays: &BTreeMap<String, ConstArrayValue>) -> Option<F> {
334+
// Handle Len specially since it needs const_arrays
335+
if let Self::Len { array, indices } = self {
336+
let idx: Option<Vec<_>> = indices
337+
.iter()
338+
.map(|e| e.naive_eval(const_arrays).map(|f| f.to_usize()))
339+
.collect();
340+
let idx = idx?;
341+
let arr = const_arrays.get(array)?;
342+
let target = arr.navigate(&idx)?;
343+
return Some(F::from_usize(target.len()));
344+
}
329345
self.eval_with(
330346
&|value: &SimpleExpr| value.as_constant()?.naive_eval(),
331347
&|arr, indexes| {
@@ -366,6 +382,7 @@ impl Expression {
366382
Some(math_expr.eval(&eval_args))
367383
}
368384
Self::FunctionCall { .. } => None,
385+
Self::Len { .. } => None, // Handled directly in naive_eval
369386
}
370387
}
371388

@@ -514,6 +531,10 @@ impl Display for Expression {
514531
let args_str = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>().join(", ");
515532
write!(f, "{function_name}({args_str})")
516533
}
534+
Self::Len { array, indices } => {
535+
let indices_str = indices.iter().map(|i| format!("[{i}]")).collect::<Vec<_>>().join("");
536+
write!(f, "len({array}{indices_str})")
537+
}
517538
}
518539
}
519540
}

crates/lean_compiler/src/parser/parsers/expression.rs

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -165,43 +165,53 @@ impl Parse<Expression> for LenParser {
165165
index_exprs.push(ExpressionParser.parse(index_pair, ctx)?);
166166
}
167167

168-
// Now evaluate the indices
168+
// Try to evaluate indices at parse time
169169
let mut indices = Vec::new();
170+
let mut all_const = true;
170171
for index_expr in &index_exprs {
171-
let index_val = evaluate_const_expr(index_expr, ctx).ok_or_else(|| {
172-
SemanticError::with_context("Index in len() must be a compile-time constant", "len expression")
173-
})?;
174-
indices.push(index_val);
172+
if let Some(index_val) = evaluate_const_expr(index_expr, ctx) {
173+
indices.push(index_val);
174+
} else {
175+
all_const = false;
176+
break;
177+
}
175178
}
176179

177-
// Now get the array again and navigate to the target sub-array
178-
let base_array = ctx.get_const_array(&ident).unwrap();
179-
let target = if indices.is_empty() {
180-
base_array
181-
} else {
182-
base_array.navigate(&indices).ok_or_else(|| {
183-
SemanticError::with_context(
184-
format!(
185-
"len() index out of bounds for '{ident}': [{}]",
186-
indices.iter().map(|i| i.to_string()).collect::<Vec<_>>().join("][")
187-
),
188-
"len expression",
189-
)
190-
})?
191-
};
180+
// If all indices are constants, evaluate len() now
181+
if all_const {
182+
let base_array = ctx.get_const_array(&ident).unwrap();
183+
let target = if indices.is_empty() {
184+
base_array
185+
} else {
186+
base_array.navigate(&indices).ok_or_else(|| {
187+
SemanticError::with_context(
188+
format!(
189+
"len() index out of bounds for '{ident}': [{}]",
190+
indices.iter().map(|i| i.to_string()).collect::<Vec<_>>().join("][")
191+
),
192+
"len expression",
193+
)
194+
})?
195+
};
192196

193-
// Get its length
194-
let length = match target {
195-
ConstArrayValue::Scalar(_) => {
196-
return Err(
197-
SemanticError::with_context("Cannot call len() on a scalar value", "len expression").into(),
198-
);
199-
}
200-
ConstArrayValue::Array(arr) => arr.len(),
201-
};
197+
let length = match target {
198+
ConstArrayValue::Scalar(_) => {
199+
return Err(
200+
SemanticError::with_context("Cannot call len() on a scalar value", "len expression").into(),
201+
);
202+
}
203+
ConstArrayValue::Array(arr) => arr.len(),
204+
};
202205

203-
Ok(Expression::Value(SimpleExpr::Constant(ConstExpression::Value(
204-
ConstantValue::Scalar(length),
205-
))))
206+
Ok(Expression::Value(SimpleExpr::Constant(ConstExpression::Value(
207+
ConstantValue::Scalar(length),
208+
))))
209+
} else {
210+
// Defer evaluation - return Expression::Len
211+
Ok(Expression::Len {
212+
array: ident,
213+
indices: index_exprs,
214+
})
215+
}
206216
}
207217
}

crates/lean_compiler/tests/test_compiler.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,3 +1376,25 @@ fn test_nested_function_call() {
13761376
false,
13771377
);
13781378
}
1379+
1380+
#[test]
1381+
fn test_len_2d_array() {
1382+
let program = r#"
1383+
const ARR = [[1], [7, 3], [7]];
1384+
const N = 2 + len(ARR[0]);
1385+
fn main() {
1386+
for i in 0..N unroll {
1387+
for j in 0..len(ARR[i]) unroll {
1388+
assert j * (j - 1) == 0;
1389+
}
1390+
}
1391+
return;
1392+
}
1393+
"#;
1394+
compile_and_run(
1395+
&ProgramSource::Raw(program.to_string()),
1396+
(&[], &[]),
1397+
DEFAULT_NO_VEC_RUNTIME_MEMORY,
1398+
false,
1399+
);
1400+
}

0 commit comments

Comments
 (0)