Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 139 additions & 47 deletions crates/lean_compiler/src/a_simplify_lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use crate::{
Counter, F,
ir::HighLevelOperation,
lang::{
Boolean, ConstExpression, ConstMallocLabel, Expression, Function, Line, Program,
SimpleExpr, Var,
AssumeBoolean, Boolean, Condition, ConstExpression, ConstMallocLabel, ConstantValue,
Expression, Function, Line, Program, SimpleExpr, Var,
},
precompiles::Precompile,
};
Expand Down Expand Up @@ -97,6 +97,12 @@ pub enum SimpleLine {
then_branch: Vec<Self>,
else_branch: Vec<Self>,
},
TestZero {
// Test that the result of the given operation is zero
operation: HighLevelOperation,
arg0: SimpleExpr,
arg1: SimpleExpr,
},
FunctionCall {
function_name: String,
args: Vec<SimpleExpr>,
Expand Down Expand Up @@ -353,26 +359,68 @@ fn simplify_lines(
then_branch,
else_branch,
} => {
// Transform if a == b then X else Y into if a != b then Y else X
let (condition_simplified, then_branch, else_branch) = match condition {
Condition::Comparison(condition) => {
// Transform if a == b then X else Y into if a != b then Y else X

let (left, right, then_branch, else_branch) = match condition {
Boolean::Equal { left, right } => {
(left, right, else_branch, then_branch)
} // switched
Boolean::Different { left, right } => {
(left, right, then_branch, else_branch)
}
};

let left_simplified =
simplify_expr(left, &mut res, counters, array_manager, const_malloc);
let right_simplified =
simplify_expr(right, &mut res, counters, array_manager, const_malloc);

let diff_var = format!("@diff_{}", counters.aux_vars);
counters.aux_vars += 1;
res.push(SimpleLine::Assignment {
var: diff_var.clone().into(),
operation: HighLevelOperation::Sub,
arg0: left_simplified,
arg1: right_simplified,
});
(diff_var.into(), then_branch, else_branch)
}
Condition::Expression(condition, assume_boolean) => {
let condition_simplified = simplify_expr(
condition,
&mut res,
counters,
array_manager,
const_malloc,
);

let (left, right, then_branch, else_branch) = match condition {
Boolean::Equal { left, right } => (left, right, else_branch, then_branch), // switched
Boolean::Different { left, right } => (left, right, then_branch, else_branch),
};
match assume_boolean {
AssumeBoolean::AssumeBoolean => {}
AssumeBoolean::DoNotAssumeBoolean => {
// Check condition_simplified is boolean
let one_minus_condition_var = format!("@aux_{}", counters.aux_vars);
counters.aux_vars += 1;
res.push(SimpleLine::Assignment {
var: one_minus_condition_var.clone().into(),
operation: HighLevelOperation::Sub,
arg0: SimpleExpr::Constant(ConstExpression::Value(
ConstantValue::Scalar(1),
)),
arg1: condition_simplified.clone(),
});
res.push(SimpleLine::TestZero {
operation: HighLevelOperation::Mul,
arg0: condition_simplified.clone(),
arg1: one_minus_condition_var.into(),
});
}
}

let left_simplified =
simplify_expr(left, &mut res, counters, array_manager, const_malloc);
let right_simplified =
simplify_expr(right, &mut res, counters, array_manager, const_malloc);

let diff_var = format!("@diff_{}", counters.aux_vars);
counters.aux_vars += 1;
res.push(SimpleLine::Assignment {
var: diff_var.clone().into(),
operation: HighLevelOperation::Sub,
arg0: left_simplified,
arg1: right_simplified,
});
(condition_simplified, then_branch, else_branch)
}
};

let forbidden_vars_before = const_malloc.forbidden_vars.clone();

Expand Down Expand Up @@ -417,7 +465,7 @@ fn simplify_lines(
.collect();

res.push(SimpleLine::IfNotZero {
condition: diff_var.into(),
condition: condition_simplified,
then_branch: then_branch_simplified,
else_branch: else_branch_simplified,
});
Expand Down Expand Up @@ -787,12 +835,20 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet<Var>, BTreeSet<Var>) {
}
};

let on_new_condition =
|condition: &Boolean, internal_vars: &BTreeSet<Var>, external_vars: &mut BTreeSet<Var>| {
let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = condition;
on_new_expr(left, internal_vars, external_vars);
on_new_expr(right, internal_vars, external_vars);
};
let on_new_condition = |condition: &Condition,
internal_vars: &BTreeSet<Var>,
external_vars: &mut BTreeSet<Var>| {
match condition {
Condition::Comparison(Boolean::Equal { left, right })
| Condition::Comparison(Boolean::Different { left, right }) => {
on_new_expr(left, internal_vars, external_vars);
on_new_expr(right, internal_vars, external_vars);
}
Condition::Expression(expr, _assume_boolean) => {
on_new_expr(expr, internal_vars, external_vars);
}
}
};

for line in lines {
match line {
Expand Down Expand Up @@ -839,7 +895,11 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet<Var>, BTreeSet<Var>) {
internal_vars.extend(return_data.iter().cloned());
}
Line::Assert(condition) => {
on_new_condition(condition, &internal_vars, &mut external_vars);
on_new_condition(
&Condition::Comparison(condition.clone()),
&internal_vars,
&mut external_vars,
);
}
Line::FunctionRet { return_data } => {
for ret in return_data {
Expand Down Expand Up @@ -944,12 +1004,17 @@ pub fn inline_lines(
res: &[Var],
inlining_count: usize,
) {
let inline_condition = |condition: &mut Boolean| {
let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = condition;
let inline_comparison = |comparison: &mut Boolean| {
let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = comparison;
inline_expr(left, args, inlining_count);
inline_expr(right, args, inlining_count);
};

let inline_condition = |condition: &mut Condition| match condition {
Condition::Comparison(comparison) => inline_comparison(comparison),
Condition::Expression(expr, _assume_boolean) => inline_expr(expr, args, inlining_count),
};

let inline_internal_var = |var: &mut Var| {
assert!(
!args.contains_key(var),
Expand Down Expand Up @@ -994,7 +1059,7 @@ pub fn inline_lines(
}
}
Line::Assert(condition) => {
inline_condition(condition);
inline_comparison(condition);
}
Line::FunctionRet { return_data } => {
assert_eq!(return_data.len(), res.len());
Expand Down Expand Up @@ -1368,24 +1433,40 @@ fn replace_vars_for_unroll(
);
}
Line::IfCondition {
condition: Boolean::Equal { left, right } | Boolean::Different { left, right },
condition,
then_branch,
else_branch,
} => {
replace_vars_for_unroll_in_expr(
left,
iterator,
unroll_index,
iterator_value,
internal_vars,
);
replace_vars_for_unroll_in_expr(
right,
iterator,
unroll_index,
iterator_value,
internal_vars,
);
match condition {
Condition::Comparison(
Boolean::Equal { left, right } | Boolean::Different { left, right },
) => {
replace_vars_for_unroll_in_expr(
left,
iterator,
unroll_index,
iterator_value,
internal_vars,
);
replace_vars_for_unroll_in_expr(
right,
iterator,
unroll_index,
iterator_value,
internal_vars,
);
}
Condition::Expression(expr, _assume_bool) => {
replace_vars_for_unroll_in_expr(
expr,
iterator,
unroll_index,
iterator_value,
internal_vars,
);
}
}

replace_vars_for_unroll(
then_branch,
iterator,
Expand Down Expand Up @@ -1972,10 +2053,14 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap<Var, F>) {
else_branch,
} => {
match condition {
Boolean::Equal { left, right } | Boolean::Different { left, right } => {
Condition::Comparison(Boolean::Equal { left, right })
| Condition::Comparison(Boolean::Different { left, right }) => {
replace_vars_by_const_in_expr(left, map);
replace_vars_by_const_in_expr(right, map);
}
Condition::Expression(expr, _assume_boolean) => {
replace_vars_by_const_in_expr(expr, map);
}
}
replace_vars_by_const_in_lines(then_branch, map);
replace_vars_by_const_in_lines(else_branch, map);
Expand Down Expand Up @@ -2116,6 +2201,13 @@ impl SimpleLine {
Self::RawAccess { res, index, shift } => {
format!("memory[{index} + {shift}] = {res}")
}
Self::TestZero {
operation,
arg0,
arg1,
} => {
format!("0 = {arg0} {operation} {arg1}")
}
Self::IfNotZero {
condition,
then_branch,
Expand Down
16 changes: 16 additions & 0 deletions crates/lean_compiler/src/b_compile_intermediate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,21 @@ fn compile_lines(
}
}

SimpleLine::TestZero {
operation,
arg0,
arg1,
} => {
instructions.push(IntermediateInstruction::computation(
*operation,
IntermediateValue::from_simple_expr(arg0, compiler),
IntermediateValue::from_simple_expr(arg1, compiler),
IntermediateValue::Constant(0.into()),
));

mark_vars_as_declared(&[arg0, arg1], declared_vars);
}

SimpleLine::Match { value, arms } => {
let match_index = compiler.match_blocks.len();
let end_label = Label::match_end(match_index);
Expand Down Expand Up @@ -768,6 +783,7 @@ fn find_internal_vars(lines: &[SimpleLine]) -> BTreeSet<Var> {
internal_vars.insert(var.clone());
}
}
SimpleLine::TestZero { .. } => {}
SimpleLine::HintMAlloc { var, .. }
| SimpleLine::ConstMalloc { var, .. }
| SimpleLine::DecomposeBits { var, .. }
Expand Down
16 changes: 9 additions & 7 deletions crates/lean_compiler/src/grammar.pest
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ array_assign = { identifier ~ "[" ~ expression ~ "]" ~ "=" ~ expression ~ ";" }

if_statement = { "if" ~ condition ~ "{" ~ statement* ~ "}" ~ else_clause? }

condition = {condition_eq | condition_diff}
condition_eq = { expression ~ "==" ~ expression }
condition_diff = { expression ~ "!=" ~ expression }
condition = { expression | assumed_bool_expr }

assumed_bool_expr = { "!!assume_bool" ~ "(" ~ expression ~ ")" }

else_clause = { "else" ~ "{" ~ statement* ~ "}" }

Expand All @@ -58,12 +58,14 @@ function_call = { function_res? ~ identifier ~ "(" ~ tuple_expression? ~ ")" ~ "
function_res = { var_list ~ "=" }
var_list = { identifier ~ ("," ~ identifier)* }

assert_eq_statement = { "assert" ~ expression ~ "==" ~ expression ~ ";" }
assert_not_eq_statement = { "assert" ~ expression ~ "!=" ~ expression ~ ";" }
assert_eq_statement = { "assert" ~ add_expr ~ "==" ~ add_expr ~ ";" }
assert_not_eq_statement = { "assert" ~ add_expr ~ "!=" ~ add_expr ~ ";" }

// Expressions
tuple_expression = { expression ~ ("," ~ expression)* }
expression = { add_expr }
expression = { neq_expr }
neq_expr = { eq_expr ~ ("!=" ~ eq_expr)* }
eq_expr = { add_expr ~ ("==" ~ add_expr)* }
add_expr = { sub_expr ~ ("+" ~ sub_expr)* }
sub_expr = { mul_expr ~ ("-" ~ mul_expr)* }
mul_expr = { mod_expr ~ ("*" ~ mod_expr)* }
Expand All @@ -85,4 +87,4 @@ constant_value = { number | "public_input_start" }

// Lexical elements
identifier = @{ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_")* }
number = @{ ASCII_DIGIT+ }
number = @{ ASCII_DIGIT+ }
5 changes: 4 additions & 1 deletion crates/lean_compiler/src/ir/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ impl IntermediateInstruction {
arg_c,
res: arg_a,
},
HighLevelOperation::Exp | HighLevelOperation::Mod => unreachable!(),
HighLevelOperation::Exp
| HighLevelOperation::Mod
| HighLevelOperation::Equal
| HighLevelOperation::NotEqual => unreachable!(),
}
}

Expand Down
20 changes: 20 additions & 0 deletions crates/lean_compiler/src/ir/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,29 @@ pub enum HighLevelOperation {
Exp,
/// Modulo operation (only for constant expressions).
Mod,
/// Equality comparison
Equal,
/// Non-equality comparison
NotEqual,
}

impl HighLevelOperation {
pub fn eval(&self, a: F, b: F) -> F {
match self {
Self::Equal => {
if a == b {
F::ONE
} else {
F::ZERO
}
}
Self::NotEqual => {
if a != b {
F::ONE
} else {
F::ZERO
}
}
Self::Add => a + b,
Self::Mul => a * b,
Self::Sub => a - b,
Expand All @@ -41,6 +59,8 @@ impl HighLevelOperation {
impl Display for HighLevelOperation {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::Equal => write!(f, "=="),
Self::NotEqual => write!(f, "!="),
Self::Add => write!(f, "+"),
Self::Mul => write!(f, "*"),
Self::Sub => write!(f, "-"),
Expand Down
Loading