Skip to content

Commit 85038a5

Browse files
committed
zkDSL compiler: use a single unified MathExpr enum instead of custom implems for each formula (making the addition of new formulas simpler)
1 parent 35b06d5 commit 85038a5

File tree

10 files changed

+202
-225
lines changed

10 files changed

+202
-225
lines changed

crates/lean_compiler/src/a_simplify_lang.rs

Lines changed: 31 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -531,12 +531,10 @@ fn check_expr_scoping(expr: &Expression, ctx: &Context) {
531531
check_expr_scoping(left, ctx);
532532
check_expr_scoping(right, ctx);
533533
}
534-
Expression::Log2Ceil { value } => {
535-
check_expr_scoping(value, ctx);
536-
}
537-
Expression::NextMultipleOf { value, multiple } => {
538-
check_expr_scoping(value, ctx);
539-
check_expr_scoping(multiple, ctx);
534+
Expression::MathExpr(_, args) => {
535+
for arg in args {
536+
check_expr_scoping(arg, ctx);
537+
}
540538
}
541539
}
542540
}
@@ -670,7 +668,7 @@ fn simplify_lines(
670668
arg1: right,
671669
});
672670
}
673-
Expression::Log2Ceil { .. } | Expression::NextMultipleOf { .. } => unreachable!(),
671+
Expression::MathExpr(_, _) => unreachable!(),
674672
},
675673
Line::ArrayAssign { array, index, value } => {
676674
handle_array_assignment(
@@ -1220,25 +1218,16 @@ fn simplify_expr(
12201218
});
12211219
SimpleExpr::Var(aux_var)
12221220
}
1223-
Expression::Log2Ceil { value } => {
1224-
let const_value = simplify_expr(value, lines, counters, array_manager, const_malloc, const_arrays)
1225-
.as_constant()
1226-
.unwrap();
1227-
SimpleExpr::Constant(ConstExpression::Log2Ceil {
1228-
value: Box::new(const_value),
1229-
})
1230-
}
1231-
Expression::NextMultipleOf { value, multiple } => {
1232-
let const_value = simplify_expr(value, lines, counters, array_manager, const_malloc, const_arrays)
1233-
.as_constant()
1234-
.unwrap();
1235-
let const_multiple = simplify_expr(multiple, lines, counters, array_manager, const_malloc, const_arrays)
1236-
.as_constant()
1237-
.unwrap();
1238-
SimpleExpr::Constant(ConstExpression::NextMultipleOf {
1239-
value: Box::new(const_value),
1240-
multiple: Box::new(const_multiple),
1241-
})
1221+
Expression::MathExpr(formula, args) => {
1222+
let simplified_args = args
1223+
.iter()
1224+
.map(|arg| {
1225+
simplify_expr(arg, lines, counters, array_manager, const_malloc, const_arrays)
1226+
.as_constant()
1227+
.unwrap()
1228+
})
1229+
.collect::<Vec<_>>();
1230+
SimpleExpr::Constant(ConstExpression::MathExpr(*formula, simplified_args))
12421231
}
12431232
}
12441233
}
@@ -1418,12 +1407,10 @@ fn inline_expr(expr: &mut Expression, args: &BTreeMap<Var, SimpleExpr>, inlining
14181407
inline_expr(left, args, inlining_count);
14191408
inline_expr(right, args, inlining_count);
14201409
}
1421-
Expression::Log2Ceil { value } => {
1422-
inline_expr(value, args, inlining_count);
1423-
}
1424-
Expression::NextMultipleOf { value, multiple } => {
1425-
inline_expr(value, args, inlining_count);
1426-
inline_expr(multiple, args, inlining_count);
1410+
Expression::MathExpr(_, math_args) => {
1411+
for arg in math_args {
1412+
inline_expr(arg, args, inlining_count);
1413+
}
14271414
}
14281415
}
14291416
}
@@ -1604,12 +1591,10 @@ fn vars_in_expression(expr: &Expression, const_arrays: &BTreeMap<String, Vec<usi
16041591
vars.extend(vars_in_expression(left, const_arrays));
16051592
vars.extend(vars_in_expression(right, const_arrays));
16061593
}
1607-
Expression::Log2Ceil { value } => {
1608-
vars.extend(vars_in_expression(value, const_arrays));
1609-
}
1610-
Expression::NextMultipleOf { value, multiple } => {
1611-
vars.extend(vars_in_expression(value, const_arrays));
1612-
vars.extend(vars_in_expression(multiple, const_arrays));
1594+
Expression::MathExpr(_, args) => {
1595+
for arg in args {
1596+
vars.extend(vars_in_expression(arg, const_arrays));
1597+
}
16131598
}
16141599
}
16151600
vars
@@ -1772,12 +1757,10 @@ fn replace_vars_for_unroll_in_expr(
17721757
replace_vars_for_unroll_in_expr(left, iterator, unroll_index, iterator_value, internal_vars);
17731758
replace_vars_for_unroll_in_expr(right, iterator, unroll_index, iterator_value, internal_vars);
17741759
}
1775-
Expression::Log2Ceil { value } => {
1776-
replace_vars_for_unroll_in_expr(value, iterator, unroll_index, iterator_value, internal_vars);
1777-
}
1778-
Expression::NextMultipleOf { value, multiple } => {
1779-
replace_vars_for_unroll_in_expr(value, iterator, unroll_index, iterator_value, internal_vars);
1780-
replace_vars_for_unroll_in_expr(multiple, iterator, unroll_index, iterator_value, internal_vars);
1760+
Expression::MathExpr(_, args) => {
1761+
for arg in args {
1762+
replace_vars_for_unroll_in_expr(arg, iterator, unroll_index, iterator_value, internal_vars);
1763+
}
17811764
}
17821765
}
17831766
}
@@ -2304,12 +2287,10 @@ fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap<Var, F>)
23042287
replace_vars_by_const_in_expr(left, map);
23052288
replace_vars_by_const_in_expr(right, map);
23062289
}
2307-
Expression::Log2Ceil { value } => {
2308-
replace_vars_by_const_in_expr(value, map);
2309-
}
2310-
Expression::NextMultipleOf { value, multiple } => {
2311-
replace_vars_by_const_in_expr(value, map);
2312-
replace_vars_by_const_in_expr(multiple, map);
2290+
Expression::MathExpr(_, args) => {
2291+
for arg in args {
2292+
replace_vars_by_const_in_expr(arg, map);
2293+
}
23132294
}
23142295
}
23152296
}

crates/lean_compiler/src/lang.rs

Lines changed: 69 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,7 @@ pub enum ConstExpression {
115115
operation: HighLevelOperation,
116116
right: Box<Self>,
117117
},
118-
Log2Ceil {
119-
value: Box<Self>,
120-
},
121-
NextMultipleOf {
122-
value: Box<Self>,
123-
multiple: Box<Self>,
124-
},
118+
MathExpr(MathExpr, Vec<Self>),
125119
}
126120

127121
impl From<usize> for ConstExpression {
@@ -147,19 +141,12 @@ impl TryFrom<Expression> for ConstExpression {
147141
right: Box::new(right_expr),
148142
})
149143
}
150-
Expression::Log2Ceil { value } => {
151-
let value_expr = Self::try_from(*value)?;
152-
Ok(Self::Log2Ceil {
153-
value: Box::new(value_expr),
154-
})
155-
}
156-
Expression::NextMultipleOf { value, multiple } => {
157-
let value_expr = Self::try_from(*value)?;
158-
let multiple_expr = Self::try_from(*multiple)?;
159-
Ok(Self::NextMultipleOf {
160-
value: Box::new(value_expr),
161-
multiple: Box::new(multiple_expr),
162-
})
144+
Expression::MathExpr(math_expr, args) => {
145+
let mut const_args = Vec::new();
146+
for arg in args {
147+
const_args.push(Self::try_from(arg)?);
148+
}
149+
Ok(Self::MathExpr(math_expr, const_args))
163150
}
164151
}
165152
}
@@ -194,17 +181,12 @@ impl ConstExpression {
194181
Self::Binary { left, operation, right } => {
195182
Some(operation.eval(left.eval_with(func)?, right.eval_with(func)?))
196183
}
197-
Self::Log2Ceil { value } => {
198-
let value = value.eval_with(func)?;
199-
Some(F::from_usize(log2_ceil_usize(value.to_usize())))
200-
}
201-
Self::NextMultipleOf { value, multiple } => {
202-
let value = value.eval_with(func)?;
203-
let multiple = multiple.eval_with(func)?;
204-
let value_usize = value.to_usize();
205-
let multiple_usize = multiple.to_usize();
206-
let res = value_usize.next_multiple_of(multiple_usize);
207-
Some(F::from_usize(res))
184+
Self::MathExpr(math_expr, args) => {
185+
let mut eval_args = Vec::new();
186+
for arg in args {
187+
eval_args.push(arg.eval_with(func)?);
188+
}
189+
Some(math_expr.eval(&eval_args))
208190
}
209191
}
210192
}
@@ -267,13 +249,50 @@ pub enum Expression {
267249
operation: HighLevelOperation,
268250
right: Box<Self>,
269251
},
270-
Log2Ceil {
271-
value: Box<Expression>,
272-
}, // only for const expressions
273-
NextMultipleOf {
274-
value: Box<Expression>,
275-
multiple: Box<Expression>,
276-
}, // only for const expressions
252+
MathExpr(MathExpr, Vec<Expression>),
253+
}
254+
255+
/// For arbitrary compile-time computations
256+
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
257+
pub enum MathExpr {
258+
Log2Ceil,
259+
NextMultipleOf,
260+
}
261+
262+
impl Display for MathExpr {
263+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
264+
match self {
265+
Self::Log2Ceil => write!(f, "log2_ceil"),
266+
Self::NextMultipleOf => write!(f, "next_multiple_of"),
267+
}
268+
}
269+
}
270+
271+
impl MathExpr {
272+
pub fn num_args(&self) -> usize {
273+
match self {
274+
Self::Log2Ceil => 1,
275+
Self::NextMultipleOf => 2,
276+
}
277+
}
278+
pub fn eval(&self, args: &[F]) -> F {
279+
match self {
280+
Self::Log2Ceil => {
281+
assert_eq!(args.len(), 1);
282+
let value = args[0];
283+
F::from_usize(log2_ceil_usize(value.to_usize()))
284+
}
285+
Self::NextMultipleOf => {
286+
assert_eq!(args.len(), 2);
287+
let value = args[0];
288+
let multiple = args[1];
289+
let value_usize = value.to_usize();
290+
let multiple_usize = multiple.to_usize();
291+
let res = value_usize.next_multiple_of(multiple_usize);
292+
F::from_usize(res)
293+
}
294+
}
295+
}
277296
}
278297

279298
impl From<SimpleExpr> for Expression {
@@ -315,17 +334,12 @@ impl Expression {
315334
left.eval_with(value_fn, array_fn)?,
316335
right.eval_with(value_fn, array_fn)?,
317336
)),
318-
Self::Log2Ceil { value } => {
319-
let value = value.eval_with(value_fn, array_fn)?;
320-
Some(F::from_usize(log2_ceil_usize(value.to_usize())))
321-
}
322-
Self::NextMultipleOf { value, multiple } => {
323-
let value = value.eval_with(value_fn, array_fn)?;
324-
let multiple = multiple.eval_with(value_fn, array_fn)?;
325-
let value_usize = value.to_usize();
326-
let multiple_usize = multiple.to_usize();
327-
let res = value_usize.next_multiple_of(multiple_usize);
328-
Some(F::from_usize(res))
337+
Self::MathExpr(math_expr, args) => {
338+
let mut eval_args = Vec::new();
339+
for arg in args {
340+
eval_args.push(arg.eval_with(value_fn, array_fn)?);
341+
}
342+
Some(math_expr.eval(&eval_args))
329343
}
330344
}
331345
}
@@ -477,11 +491,9 @@ impl Display for Expression {
477491
Self::Binary { left, operation, right } => {
478492
write!(f, "({left} {operation} {right})")
479493
}
480-
Self::Log2Ceil { value } => {
481-
write!(f, "log2_ceil({value})")
482-
}
483-
Self::NextMultipleOf { value, multiple } => {
484-
write!(f, "next_multiple_of({value}, {multiple})")
494+
Self::MathExpr(math_expr, args) => {
495+
let args_str = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>().join(", ");
496+
write!(f, "{}({})", math_expr, args_str)
485497
}
486498
}
487499
}
@@ -693,11 +705,9 @@ impl Display for ConstExpression {
693705
Self::Binary { left, operation, right } => {
694706
write!(f, "({left} {operation} {right})")
695707
}
696-
Self::Log2Ceil { value } => {
697-
write!(f, "log2_ceil({value})")
698-
}
699-
Self::NextMultipleOf { value, multiple } => {
700-
write!(f, "next_multiple_of({value}, {multiple})")
708+
Self::MathExpr(math_expr, args) => {
709+
let args_str = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>().join(", ");
710+
write!(f, "{}({})", math_expr, args_str)
701711
}
702712
}
703713
}

crates/lean_compiler/src/parser/grammar.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
use pest::Parser;
2+
pub use pest::iterators::Pair;
23
use pest_derive::Parser;
34

45
/// Main parser struct generated by Pest.
56
#[derive(Parser)]
67
#[grammar = "grammar.pest"]
78
pub struct LangParser;
89

9-
pub use pest::iterators::Pair;
10-
1110
/// Type alias for a parsed grammar rule.
1211
pub type ParsePair<'i> = Pair<'i, Rule>;
1312

crates/lean_compiler/src/parser/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ mod grammar;
55
mod lexer;
66
mod parsers;
77

8-
pub use error::ParseError;
98
pub use grammar::parse_source;
109
pub use parsers::program::ProgramParser;
1110
pub use parsers::{Parse, ParseContext};
1211

1312
use crate::lang::Program;
13+
use crate::parser::error::ParseError;
1414
use std::collections::BTreeMap;
1515

1616
/// Main entry point for parsing Lean programs.
@@ -23,5 +23,5 @@ pub fn parse_program(input: &str) -> Result<(Program, BTreeMap<usize, String>),
2323

2424
// Parse into semantic structures
2525
let mut ctx = ParseContext::new();
26-
ProgramParser::parse(program_pair, &mut ctx)
26+
ProgramParser.parse(program_pair, &mut ctx)
2727
}

0 commit comments

Comments
 (0)