Skip to content

Commit 3ceaf1b

Browse files
committed
zkDSL new feature: constant arrays (compile time)
1 parent 813b005 commit 3ceaf1b

File tree

8 files changed

+345
-130
lines changed

8 files changed

+345
-130
lines changed

crates/lean_compiler/src/a_simplify_lang.rs

Lines changed: 142 additions & 96 deletions
Large diffs are not rendered by default.

crates/lean_compiler/src/grammar.pest

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ WHITESPACE = _{ " " | "\t" | "\n" | "\r" }
44
program = { SOI ~ constant_declaration* ~ function+ ~ EOI }
55

66
// Constants
7-
constant_declaration = { "const" ~ identifier ~ "=" ~ expression ~ ";" }
7+
constant_declaration = { "const" ~ identifier ~ "=" ~ (array_literal | expression) ~ ";" }
8+
array_literal = { "[" ~ (expression ~ ("," ~ expression)*)? ~ "]" }
89

910
// Functions
1011
function = { pragma? ~ "fn" ~ identifier ~ "(" ~ parameter_list? ~ ")" ~ inlined_statement? ~ return_count? ~ "{" ~ statement* ~ "}" }
@@ -89,17 +90,19 @@ primary = {
8990
"(" ~ expression ~ ")" |
9091
log2_ceil_expr |
9192
next_multiple_of_expr |
93+
len_expr |
9294
array_access_expr |
9395
var_or_constant
9496
}
9597
log2_ceil_expr = { "log2_ceil" ~ "(" ~ expression ~ ")" }
9698
next_multiple_of_expr = { "next_multiple_of" ~ "(" ~ expression ~ "," ~ expression ~ ")" }
99+
len_expr = { "len" ~ "(" ~ identifier ~ ")" }
97100
array_access_expr = { identifier ~ "[" ~ expression ~ "]" }
98101

99102
// Basic elements
100103
var_or_constant = { constant_value | identifier }
101-
constant_value = { number | "public_input_start" }
104+
constant_value = { number | "public_input_start" | "pointer_to_zero_vector" | "pointer_to_one_vector" }
102105

103106
// Lexical elements
104107
identifier = @{ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_")* }
105-
number = @{ ASCII_DIGIT+ }
108+
number = @{ ASCII_DIGIT+ }

crates/lean_compiler/src/lang.rs

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use crate::{F, ir::HighLevelOperation};
1010
#[derive(Debug, Clone)]
1111
pub struct Program {
1212
pub functions: BTreeMap<String, Function>,
13+
pub const_arrays: BTreeMap<String, Vec<usize>>,
1314
}
1415

1516
#[derive(Debug, Clone)]
@@ -288,8 +289,18 @@ impl From<Var> for Expression {
288289
}
289290

290291
impl Expression {
291-
pub fn naive_eval(&self) -> Option<F> {
292-
self.eval_with(&|value: &SimpleExpr| value.as_constant()?.naive_eval(), &|_, _| None)
292+
pub fn naive_eval(&self, const_arrays: &BTreeMap<String, Vec<usize>>) -> Option<F> {
293+
self.eval_with(
294+
&|value: &SimpleExpr| value.as_constant()?.naive_eval(),
295+
&|arr, index| {
296+
let SimpleExpr::Var(name) = arr else {
297+
return None;
298+
};
299+
let index_usize = index.to_usize();
300+
let array = const_arrays.get(name)?;
301+
Some(F::from_usize(*array.get(index_usize)?))
302+
},
303+
)
293304
}
294305

295306
pub fn eval_with<ValueFn, ArrayFn>(&self, value_fn: &ValueFn, array_fn: &ArrayFn) -> Option<F>
@@ -417,10 +428,15 @@ pub enum Line {
417428
pub struct Context {
418429
/// A list of lexical scopes, innermost scope last.
419430
pub scopes: Vec<Scope>,
431+
/// A mapping from constant array names to their values.
432+
pub const_arrays: BTreeMap<String, Vec<usize>>,
420433
}
421434

422435
impl Context {
423436
pub fn defines(&self, var: &Var) -> bool {
437+
if self.const_arrays.contains_key(var) {
438+
return true;
439+
}
424440
for scope in self.scopes.iter() {
425441
if scope.vars.contains(var) {
426442
return true;
@@ -680,7 +696,19 @@ impl Display for Line {
680696

681697
impl Display for Program {
682698
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
683-
let mut first = true;
699+
// Print const arrays
700+
for (name, values) in &self.const_arrays {
701+
write!(f, "const {name} = [")?;
702+
for (i, v) in values.iter().enumerate() {
703+
if i > 0 {
704+
write!(f, ", ")?;
705+
}
706+
write!(f, "{v}")?;
707+
}
708+
writeln!(f, "];")?;
709+
}
710+
711+
let mut first = self.const_arrays.is_empty();
684712
for function in self.functions.values() {
685713
if !first {
686714
writeln!(f)?;

crates/lean_compiler/src/parser/parsers/expression.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use super::literal::VarOrConstantParser;
22
use super::{Parse, ParseContext, next_inner_pair};
33
use crate::{
44
ir::HighLevelOperation,
5-
lang::Expression,
5+
lang::{ConstExpression, ConstantValue, Expression, SimpleExpr},
66
parser::{
77
error::{ParseError, ParseResult, SemanticError},
88
grammar::{ParsePair, Rule},
@@ -76,6 +76,7 @@ impl Parse<Expression> for PrimaryExpressionParser {
7676
Rule::array_access_expr => ArrayAccessParser::parse(inner, ctx),
7777
Rule::log2_ceil_expr => Log2CeilParser::parse(inner, ctx),
7878
Rule::next_multiple_of_expr => NextMultipleOfParser::parse(inner, ctx),
79+
Rule::len_expr => LenParser::parse(inner, ctx),
7980
_ => Err(SemanticError::new("Invalid primary expression").into()),
8081
}
8182
}
@@ -124,3 +125,25 @@ impl Parse<Expression> for NextMultipleOfParser {
124125
})
125126
}
126127
}
128+
129+
/// Parser for len() expressions on const arrays.
130+
pub struct LenParser;
131+
132+
impl Parse<Expression> for LenParser {
133+
fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult<Expression> {
134+
let mut inner = pair.into_inner();
135+
let ident = next_inner_pair(&mut inner, "len argument")?.as_str();
136+
137+
if let Some(arr) = ctx.get_const_array(ident) {
138+
Ok(Expression::Value(SimpleExpr::Constant(ConstExpression::Value(
139+
ConstantValue::Scalar(arr.len()),
140+
))))
141+
} else {
142+
Err(SemanticError::with_context(
143+
format!("len() argument '{ident}' is not a const array"),
144+
"len expression",
145+
)
146+
.into())
147+
}
148+
}
149+
}

crates/lean_compiler/src/parser/parsers/literal.rs

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::expression::ExpressionParser;
2-
use super::{Parse, ParseContext, next_inner_pair};
2+
use super::{Parse, ParseContext, ParsedConstant, next_inner_pair};
33
use crate::{
44
F,
55
lang::{ConstExpression, ConstantValue, SimpleExpr},
@@ -14,30 +14,61 @@ use utils::ToUsize;
1414
/// Parser for constant declarations.
1515
pub struct ConstantDeclarationParser;
1616

17-
impl Parse<(String, usize)> for ConstantDeclarationParser {
18-
fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult<(String, usize)> {
17+
impl Parse<(String, ParsedConstant)> for ConstantDeclarationParser {
18+
fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult<(String, ParsedConstant)> {
1919
let mut inner = pair.into_inner();
2020
let name = next_inner_pair(&mut inner, "constant name")?.as_str().to_string();
2121
let value_pair = next_inner_pair(&mut inner, "constant value")?;
2222

23-
// Parse the expression and evaluate it
24-
let expr = ExpressionParser::parse(value_pair, ctx)?;
23+
match value_pair.as_rule() {
24+
Rule::array_literal => {
25+
let values: Vec<usize> = value_pair
26+
.into_inner()
27+
.map(|expr_pair| {
28+
let expr = ExpressionParser::parse(expr_pair, ctx).unwrap();
29+
expr.eval_with(
30+
&|simple_expr| match simple_expr {
31+
SimpleExpr::Constant(cst) => cst.naive_eval(),
32+
SimpleExpr::Var(var) => ctx.get_constant(var).map(F::from_usize),
33+
SimpleExpr::ConstMallocAccess { .. } => None,
34+
},
35+
&|_, _| None,
36+
)
37+
.ok_or_else(|| {
38+
SemanticError::with_context(
39+
format!("Failed to evaluate array element in constant: {name}"),
40+
"constant declaration",
41+
)
42+
})
43+
.map(|f| f.to_usize())
44+
})
45+
.collect::<Result<Vec<_>, _>>()?;
46+
Ok((name, ParsedConstant::Array(values)))
47+
}
48+
_ => {
49+
// Parse the expression and evaluate it
50+
let expr = ExpressionParser::parse(value_pair, ctx)?;
2551

26-
let value = expr
27-
.eval_with(
28-
&|simple_expr| match simple_expr {
29-
SimpleExpr::Constant(cst) => cst.naive_eval(),
30-
SimpleExpr::Var(var) => ctx.get_constant(var).map(F::from_usize),
31-
SimpleExpr::ConstMallocAccess { .. } => None, // Not allowed in constants
32-
},
33-
&|_, _| None,
34-
)
35-
.ok_or_else(|| {
36-
SemanticError::with_context(format!("Failed to evaluate constant: {name}"), "constant declaration")
37-
})?
38-
.to_usize();
52+
let value = expr
53+
.eval_with(
54+
&|simple_expr| match simple_expr {
55+
SimpleExpr::Constant(cst) => cst.naive_eval(),
56+
SimpleExpr::Var(var) => ctx.get_constant(var).map(F::from_usize),
57+
SimpleExpr::ConstMallocAccess { .. } => None,
58+
},
59+
&|_, _| None,
60+
)
61+
.ok_or_else(|| {
62+
SemanticError::with_context(
63+
format!("Failed to evaluate constant: {name}"),
64+
"constant declaration",
65+
)
66+
})?
67+
.to_usize();
3968

40-
Ok((name, value))
69+
Ok((name, ParsedConstant::Scalar(value)))
70+
}
71+
}
4172
}
4273
}
4374

@@ -73,6 +104,15 @@ impl VarOrConstantParser {
73104
ConstantValue::PointerToOneVector,
74105
))),
75106
_ => {
107+
// Check if it's a const array (error case - can't use array as value)
108+
if ctx.get_const_array(text).is_some() {
109+
return Err(SemanticError::with_context(
110+
format!("Cannot use const array '{text}' as a value directly (use indexing or len())"),
111+
"variable reference",
112+
)
113+
.into());
114+
}
115+
76116
// Try to resolve as defined constant
77117
if let Some(value) = ctx.get_constant(text) {
78118
Ok(SimpleExpr::Constant(ConstExpression::Value(ConstantValue::Scalar(

crates/lean_compiler/src/parser/parsers/mod.rs

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,20 @@ pub mod literal;
1010
pub mod program;
1111
pub mod statement;
1212

13+
/// Represents a parsed constant value (scalar or array).
14+
#[derive(Debug, Clone)]
15+
pub enum ParsedConstant {
16+
Scalar(usize),
17+
Array(Vec<usize>),
18+
}
19+
1320
/// Core parsing context that all parsers share.
1421
#[derive(Debug)]
1522
pub struct ParseContext {
16-
/// Compile-time constants defined in the program
23+
/// Compile-time scalar constants defined in the program
1724
pub constants: BTreeMap<String, usize>,
25+
/// Compile-time array constants defined in the program
26+
pub const_arrays: BTreeMap<String, Vec<usize>>,
1827
/// Counter for generating unique trash variable names
1928
pub trash_var_count: usize,
2029
}
@@ -23,27 +32,47 @@ impl ParseContext {
2332
pub const fn new() -> Self {
2433
Self {
2534
constants: BTreeMap::new(),
35+
const_arrays: BTreeMap::new(),
2636
trash_var_count: 0,
2737
}
2838
}
2939

30-
/// Adds a constant to the context.
40+
/// Adds a scalar constant to the context.
3141
pub fn add_constant(&mut self, name: String, value: usize) -> Result<(), SemanticError> {
32-
if self.constants.insert(name.clone(), value).is_some() {
42+
if self.constants.contains_key(&name) || self.const_arrays.contains_key(&name) {
43+
Err(SemanticError::with_context(
44+
format!("Defined multiple times: {name}"),
45+
"constant declaration",
46+
))
47+
} else {
48+
self.constants.insert(name, value);
49+
Ok(())
50+
}
51+
}
52+
53+
/// Adds an array constant to the context.
54+
pub fn add_const_array(&mut self, name: String, values: Vec<usize>) -> Result<(), SemanticError> {
55+
if self.constants.contains_key(&name) || self.const_arrays.contains_key(&name) {
3356
Err(SemanticError::with_context(
34-
format!("Multiply defined constant: {name}"),
57+
format!("Defined multiple times: {name}"),
3558
"constant declaration",
3659
))
3760
} else {
61+
self.const_arrays.insert(name, values);
3862
Ok(())
3963
}
4064
}
4165

42-
/// Looks up a constant value.
66+
/// Looks up a scalar constant value.
4367
pub fn get_constant(&self, name: &str) -> Option<usize> {
4468
self.constants.get(name).copied()
4569
}
4670

71+
/// Looks up an array constant.
72+
pub fn get_const_array(&self, name: &str) -> Option<&Vec<usize>> {
73+
self.const_arrays.get(name)
74+
}
75+
4776
/// Generates a unique trash variable name.
4877
pub fn next_trash_var(&mut self) -> String {
4978
self.trash_var_count += 1;

crates/lean_compiler/src/parser/parsers/program.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use super::function::FunctionParser;
22
use super::literal::ConstantDeclarationParser;
3-
use super::{Parse, ParseContext};
3+
use super::{Parse, ParseContext, ParsedConstant};
44
use crate::{
55
lang::Program,
66
parser::{
@@ -23,7 +23,10 @@ impl Parse<(Program, BTreeMap<usize, String>)> for ProgramParser {
2323
match item.as_rule() {
2424
Rule::constant_declaration => {
2525
let (name, value) = ConstantDeclarationParser::parse(item, &mut ctx)?;
26-
ctx.add_constant(name, value)?;
26+
match value {
27+
ParsedConstant::Scalar(v) => ctx.add_constant(name, v)?,
28+
ParsedConstant::Array(arr) => ctx.add_const_array(name, arr)?,
29+
}
2730
}
2831
Rule::function => {
2932
let location = item.line_col().0;
@@ -45,6 +48,12 @@ impl Parse<(Program, BTreeMap<usize, String>)> for ProgramParser {
4548
}
4649
}
4750

48-
Ok((Program { functions }, function_locations))
51+
Ok((
52+
Program {
53+
functions,
54+
const_arrays: ctx.const_arrays,
55+
},
56+
function_locations,
57+
))
4958
}
5059
}

crates/lean_compiler/tests/test_compiler.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,3 +729,40 @@ fn test_next_multiple_of() {
729729
"#;
730730
compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false);
731731
}
732+
733+
#[test]
734+
fn test_const_array() {
735+
let program = r#"
736+
const FIVE = 5;
737+
const ARR = [4, FIVE, 4 + 2, 3 * 2 + 1];
738+
fn main() {
739+
for i in 1..len(ARR) unroll {
740+
x = i + 4;
741+
assert ARR[i] == x;
742+
}
743+
four = 4;
744+
assert len(ARR) == four;
745+
res = func(2);
746+
six = 6;
747+
assert res == six;
748+
nothing(ARR[0]);
749+
mem_arr = malloc(len(ARR));
750+
for i in 0..len(ARR) unroll {
751+
mem_arr[i] = ARR[i];
752+
}
753+
for i in 0..ARR[0] {
754+
print(2**ARR[0]);
755+
}
756+
print(2**ARR[1]);
757+
return;
758+
}
759+
760+
fn func(const x) -> 1 {
761+
return ARR[x];
762+
}
763+
fn nothing(x) {
764+
return;
765+
}
766+
"#;
767+
compile_and_run(program.to_string(), (&[], &[]), DEFAULT_NO_VEC_RUNTIME_MEMORY, false);
768+
}

0 commit comments

Comments
 (0)