diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index 90489cbb..558ff452 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -2,8 +2,8 @@ use crate::{ Counter, F, ir::HighLevelOperation, lang::{ - Boolean, ConstExpression, ConstMallocLabel, Expression, Function, Line, Program, - SimpleExpr, Var, + AssumeBoolean, Boolean, Condition, ConstExpression, ConstMallocLabel, ConstantValue, + Expression, Function, Line, Program, SimpleExpr, Var, }, precompiles::Precompile, }; @@ -97,6 +97,12 @@ pub enum SimpleLine { then_branch: Vec, else_branch: Vec, }, + TestZero { + // Test that the result of the given operation is zero + operation: HighLevelOperation, + arg0: SimpleExpr, + arg1: SimpleExpr, + }, FunctionCall { function_name: String, args: Vec, @@ -353,26 +359,68 @@ fn simplify_lines( then_branch, else_branch, } => { - // Transform if a == b then X else Y into if a != b then Y else X + let (condition_simplified, then_branch, else_branch) = match condition { + Condition::Comparison(condition) => { + // Transform if a == b then X else Y into if a != b then Y else X + + let (left, right, then_branch, else_branch) = match condition { + Boolean::Equal { left, right } => { + (left, right, else_branch, then_branch) + } // switched + Boolean::Different { left, right } => { + (left, right, then_branch, else_branch) + } + }; + + let left_simplified = + simplify_expr(left, &mut res, counters, array_manager, const_malloc); + let right_simplified = + simplify_expr(right, &mut res, counters, array_manager, const_malloc); + + let diff_var = format!("@diff_{}", counters.aux_vars); + counters.aux_vars += 1; + res.push(SimpleLine::Assignment { + var: diff_var.clone().into(), + operation: HighLevelOperation::Sub, + arg0: left_simplified, + arg1: right_simplified, + }); + (diff_var.into(), then_branch, else_branch) + } + Condition::Expression(condition, assume_boolean) => { + let condition_simplified = simplify_expr( + condition, + &mut res, + counters, + array_manager, + const_malloc, + ); - let (left, right, then_branch, else_branch) = match condition { - Boolean::Equal { left, right } => (left, right, else_branch, then_branch), // switched - Boolean::Different { left, right } => (left, right, then_branch, else_branch), - }; + match assume_boolean { + AssumeBoolean::AssumeBoolean => {} + AssumeBoolean::DoNotAssumeBoolean => { + // Check condition_simplified is boolean + let one_minus_condition_var = format!("@aux_{}", counters.aux_vars); + counters.aux_vars += 1; + res.push(SimpleLine::Assignment { + var: one_minus_condition_var.clone().into(), + operation: HighLevelOperation::Sub, + arg0: SimpleExpr::Constant(ConstExpression::Value( + ConstantValue::Scalar(1), + )), + arg1: condition_simplified.clone(), + }); + res.push(SimpleLine::TestZero { + operation: HighLevelOperation::Mul, + arg0: condition_simplified.clone(), + arg1: one_minus_condition_var.into(), + }); + } + } - let left_simplified = - simplify_expr(left, &mut res, counters, array_manager, const_malloc); - let right_simplified = - simplify_expr(right, &mut res, counters, array_manager, const_malloc); - - let diff_var = format!("@diff_{}", counters.aux_vars); - counters.aux_vars += 1; - res.push(SimpleLine::Assignment { - var: diff_var.clone().into(), - operation: HighLevelOperation::Sub, - arg0: left_simplified, - arg1: right_simplified, - }); + (condition_simplified, then_branch, else_branch) + } + }; let forbidden_vars_before = const_malloc.forbidden_vars.clone(); @@ -417,7 +465,7 @@ fn simplify_lines( .collect(); res.push(SimpleLine::IfNotZero { - condition: diff_var.into(), + condition: condition_simplified, then_branch: then_branch_simplified, else_branch: else_branch_simplified, }); @@ -787,12 +835,20 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet) { } }; - let on_new_condition = - |condition: &Boolean, internal_vars: &BTreeSet, external_vars: &mut BTreeSet| { - let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = condition; - on_new_expr(left, internal_vars, external_vars); - on_new_expr(right, internal_vars, external_vars); - }; + let on_new_condition = |condition: &Condition, + internal_vars: &BTreeSet, + external_vars: &mut BTreeSet| { + match condition { + Condition::Comparison(Boolean::Equal { left, right }) + | Condition::Comparison(Boolean::Different { left, right }) => { + on_new_expr(left, internal_vars, external_vars); + on_new_expr(right, internal_vars, external_vars); + } + Condition::Expression(expr, _assume_boolean) => { + on_new_expr(expr, internal_vars, external_vars); + } + } + }; for line in lines { match line { @@ -839,7 +895,11 @@ pub fn find_variable_usage(lines: &[Line]) -> (BTreeSet, BTreeSet) { internal_vars.extend(return_data.iter().cloned()); } Line::Assert(condition) => { - on_new_condition(condition, &internal_vars, &mut external_vars); + on_new_condition( + &Condition::Comparison(condition.clone()), + &internal_vars, + &mut external_vars, + ); } Line::FunctionRet { return_data } => { for ret in return_data { @@ -944,12 +1004,17 @@ pub fn inline_lines( res: &[Var], inlining_count: usize, ) { - let inline_condition = |condition: &mut Boolean| { - let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = condition; + let inline_comparison = |comparison: &mut Boolean| { + let (Boolean::Equal { left, right } | Boolean::Different { left, right }) = comparison; inline_expr(left, args, inlining_count); inline_expr(right, args, inlining_count); }; + let inline_condition = |condition: &mut Condition| match condition { + Condition::Comparison(comparison) => inline_comparison(comparison), + Condition::Expression(expr, _assume_boolean) => inline_expr(expr, args, inlining_count), + }; + let inline_internal_var = |var: &mut Var| { assert!( !args.contains_key(var), @@ -994,7 +1059,7 @@ pub fn inline_lines( } } Line::Assert(condition) => { - inline_condition(condition); + inline_comparison(condition); } Line::FunctionRet { return_data } => { assert_eq!(return_data.len(), res.len()); @@ -1368,24 +1433,40 @@ fn replace_vars_for_unroll( ); } Line::IfCondition { - condition: Boolean::Equal { left, right } | Boolean::Different { left, right }, + condition, then_branch, else_branch, } => { - replace_vars_for_unroll_in_expr( - left, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); - replace_vars_for_unroll_in_expr( - right, - iterator, - unroll_index, - iterator_value, - internal_vars, - ); + match condition { + Condition::Comparison( + Boolean::Equal { left, right } | Boolean::Different { left, right }, + ) => { + replace_vars_for_unroll_in_expr( + left, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + replace_vars_for_unroll_in_expr( + right, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } + Condition::Expression(expr, _assume_bool) => { + replace_vars_for_unroll_in_expr( + expr, + iterator, + unroll_index, + iterator_value, + internal_vars, + ); + } + } + replace_vars_for_unroll( then_branch, iterator, @@ -1972,10 +2053,14 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap) { else_branch, } => { match condition { - Boolean::Equal { left, right } | Boolean::Different { left, right } => { + Condition::Comparison(Boolean::Equal { left, right }) + | Condition::Comparison(Boolean::Different { left, right }) => { replace_vars_by_const_in_expr(left, map); replace_vars_by_const_in_expr(right, map); } + Condition::Expression(expr, _assume_boolean) => { + replace_vars_by_const_in_expr(expr, map); + } } replace_vars_by_const_in_lines(then_branch, map); replace_vars_by_const_in_lines(else_branch, map); @@ -2116,6 +2201,13 @@ impl SimpleLine { Self::RawAccess { res, index, shift } => { format!("memory[{index} + {shift}] = {res}") } + Self::TestZero { + operation, + arg0, + arg1, + } => { + format!("0 = {arg0} {operation} {arg1}") + } Self::IfNotZero { condition, then_branch, diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index c040ec78..92818757 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -184,6 +184,21 @@ fn compile_lines( } } + SimpleLine::TestZero { + operation, + arg0, + arg1, + } => { + instructions.push(IntermediateInstruction::computation( + *operation, + IntermediateValue::from_simple_expr(arg0, compiler), + IntermediateValue::from_simple_expr(arg1, compiler), + IntermediateValue::Constant(0.into()), + )); + + mark_vars_as_declared(&[arg0, arg1], declared_vars); + } + SimpleLine::Match { value, arms } => { let match_index = compiler.match_blocks.len(); let end_label = Label::match_end(match_index); @@ -768,6 +783,7 @@ fn find_internal_vars(lines: &[SimpleLine]) -> BTreeSet { internal_vars.insert(var.clone()); } } + SimpleLine::TestZero { .. } => {} SimpleLine::HintMAlloc { var, .. } | SimpleLine::ConstMalloc { var, .. } | SimpleLine::DecomposeBits { var, .. } diff --git a/crates/lean_compiler/src/grammar.pest b/crates/lean_compiler/src/grammar.pest index e0b5ddd4..d6fc7d1a 100644 --- a/crates/lean_compiler/src/grammar.pest +++ b/crates/lean_compiler/src/grammar.pest @@ -40,9 +40,9 @@ array_assign = { identifier ~ "[" ~ expression ~ "]" ~ "=" ~ expression ~ ";" } if_statement = { "if" ~ condition ~ "{" ~ statement* ~ "}" ~ else_clause? } -condition = {condition_eq | condition_diff} -condition_eq = { expression ~ "==" ~ expression } -condition_diff = { expression ~ "!=" ~ expression } +condition = { expression | assumed_bool_expr } + +assumed_bool_expr = { "!!assume_bool" ~ "(" ~ expression ~ ")" } else_clause = { "else" ~ "{" ~ statement* ~ "}" } @@ -58,12 +58,14 @@ function_call = { function_res? ~ identifier ~ "(" ~ tuple_expression? ~ ")" ~ " function_res = { var_list ~ "=" } var_list = { identifier ~ ("," ~ identifier)* } -assert_eq_statement = { "assert" ~ expression ~ "==" ~ expression ~ ";" } -assert_not_eq_statement = { "assert" ~ expression ~ "!=" ~ expression ~ ";" } +assert_eq_statement = { "assert" ~ add_expr ~ "==" ~ add_expr ~ ";" } +assert_not_eq_statement = { "assert" ~ add_expr ~ "!=" ~ add_expr ~ ";" } // Expressions tuple_expression = { expression ~ ("," ~ expression)* } -expression = { add_expr } +expression = { neq_expr } +neq_expr = { eq_expr ~ ("!=" ~ eq_expr)* } +eq_expr = { add_expr ~ ("==" ~ add_expr)* } add_expr = { sub_expr ~ ("+" ~ sub_expr)* } sub_expr = { mul_expr ~ ("-" ~ mul_expr)* } mul_expr = { mod_expr ~ ("*" ~ mod_expr)* } @@ -85,4 +87,4 @@ constant_value = { number | "public_input_start" } // Lexical elements identifier = @{ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_")* } -number = @{ ASCII_DIGIT+ } \ No newline at end of file +number = @{ ASCII_DIGIT+ } diff --git a/crates/lean_compiler/src/ir/instruction.rs b/crates/lean_compiler/src/ir/instruction.rs index 97cb1d59..40e73770 100644 --- a/crates/lean_compiler/src/ir/instruction.rs +++ b/crates/lean_compiler/src/ir/instruction.rs @@ -120,7 +120,10 @@ impl IntermediateInstruction { arg_c, res: arg_a, }, - HighLevelOperation::Exp | HighLevelOperation::Mod => unreachable!(), + HighLevelOperation::Exp + | HighLevelOperation::Mod + | HighLevelOperation::Equal + | HighLevelOperation::NotEqual => unreachable!(), } } diff --git a/crates/lean_compiler/src/ir/operation.rs b/crates/lean_compiler/src/ir/operation.rs index 39325468..d829e5e0 100644 --- a/crates/lean_compiler/src/ir/operation.rs +++ b/crates/lean_compiler/src/ir/operation.rs @@ -23,11 +23,29 @@ pub enum HighLevelOperation { Exp, /// Modulo operation (only for constant expressions). Mod, + /// Equality comparison + Equal, + /// Non-equality comparison + NotEqual, } impl HighLevelOperation { pub fn eval(&self, a: F, b: F) -> F { match self { + Self::Equal => { + if a == b { + F::ONE + } else { + F::ZERO + } + } + Self::NotEqual => { + if a != b { + F::ONE + } else { + F::ZERO + } + } Self::Add => a + b, Self::Mul => a * b, Self::Sub => a - b, @@ -41,6 +59,8 @@ impl HighLevelOperation { impl Display for HighLevelOperation { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { + Self::Equal => write!(f, "=="), + Self::NotEqual => write!(f, "!="), Self::Add => write!(f, "+"), Self::Mul => write!(f, "*"), Self::Sub => write!(f, "-"), diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index 0c1a989e..7eea1ea0 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -221,6 +221,30 @@ impl From for ConstExpression { } } +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum AssumeBoolean { + AssumeBoolean, + DoNotAssumeBoolean, +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Condition { + Expression(Expression, AssumeBoolean), + Comparison(Boolean), +} + +impl Display for Condition { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Expression(expr, AssumeBoolean::AssumeBoolean) => { + write!(f, "!assume_bool({expr})") + } + Self::Expression(expr, AssumeBoolean::DoNotAssumeBoolean) => write!(f, "{expr}"), + Self::Comparison(cmp) => write!(f, "{cmp}"), + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Expression { Value(SimpleExpr), @@ -310,7 +334,7 @@ pub enum Line { }, Assert(Boolean), IfCondition { - condition: Boolean, + condition: Condition, then_branch: Vec, else_branch: Vec, }, diff --git a/crates/lean_compiler/src/parser/parsers/expression.rs b/crates/lean_compiler/src/parser/parsers/expression.rs index e3ce0d42..39c5c70d 100644 --- a/crates/lean_compiler/src/parser/parsers/expression.rs +++ b/crates/lean_compiler/src/parser/parsers/expression.rs @@ -4,7 +4,7 @@ use crate::{ ir::HighLevelOperation, lang::Expression, parser::{ - error::{ParseResult, SemanticError}, + error::{ParseError, ParseResult, SemanticError}, grammar::{ParsePair, Rule}, }, }; @@ -19,6 +19,12 @@ impl Parse for ExpressionParser { let inner = next_inner_pair(&mut pair.into_inner(), "expression body")?; Self::parse(inner, ctx) } + Rule::neq_expr => { + BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::NotEqual) + } + Rule::eq_expr => { + BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Equal) + } Rule::add_expr => { BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Add) } @@ -38,7 +44,9 @@ impl Parse for ExpressionParser { BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Exp) } Rule::primary => PrimaryExpressionParser::parse(pair, ctx), - _ => Err(SemanticError::new("Invalid expression").into()), + other_rule => Err(ParseError::SemanticError(SemanticError::new(format!( + "ExpressionParser: Unexpected rule {other_rule:?}" + )))), } } } diff --git a/crates/lean_compiler/src/parser/parsers/statement.rs b/crates/lean_compiler/src/parser/parsers/statement.rs index 0f2b90e8..2d4e0fa9 100644 --- a/crates/lean_compiler/src/parser/parsers/statement.rs +++ b/crates/lean_compiler/src/parser/parsers/statement.rs @@ -3,7 +3,8 @@ use super::function::{FunctionCallParser, TupleExpressionParser}; use super::literal::ConstExprParser; use super::{Parse, ParseContext, next_inner_pair}; use crate::{ - lang::{Boolean, Line}, + ir::HighLevelOperation, + lang::{AssumeBoolean, Boolean, Condition, Expression, Line}, parser::{ error::{ParseResult, SemanticError}, grammar::{ParsePair, Rule}, @@ -123,6 +124,47 @@ impl IfStatementParser { } } +/// Parser for conditions. +pub struct ConditionParser; + +impl Parse for ConditionParser { + fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { + let inner_pair = next_inner_pair(&mut pair.into_inner(), "inner expression")?; + if inner_pair.as_rule() == Rule::assumed_bool_expr { + ExpressionParser::parse( + next_inner_pair(&mut inner_pair.into_inner(), "inner expression")?, + ctx, + ) + .map(|e| Condition::Expression(e, AssumeBoolean::AssumeBoolean)) + } else { + let expr_result = ExpressionParser::parse(inner_pair, ctx); + match expr_result { + Err(e) => Err(e), + Ok(Expression::Binary { + left, + operation: HighLevelOperation::Equal, + right, + }) => Ok(Condition::Comparison(Boolean::Equal { + left: *left, + right: *right, + })), + Ok(Expression::Binary { + left, + operation: HighLevelOperation::NotEqual, + right, + }) => Ok(Condition::Comparison(Boolean::Different { + left: *left, + right: *right, + })), + Ok(expr) => Ok(Condition::Expression( + expr, + AssumeBoolean::DoNotAssumeBoolean, + )), + } + } + } +} + /// Parser for for-loop statements. pub struct ForStatementParser; @@ -251,25 +293,6 @@ impl Parse for ReturnStatementParser { } } -/// Parser for boolean conditions used in if statements and assertions. -pub struct ConditionParser; - -impl Parse for ConditionParser { - fn parse(pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { - let inner = next_inner_pair(&mut pair.into_inner(), "condition")?; - let mut parts = inner.clone().into_inner(); - - let left = ExpressionParser::parse(next_inner_pair(&mut parts, "left side")?, ctx)?; - let right = ExpressionParser::parse(next_inner_pair(&mut parts, "right side")?, ctx)?; - - match inner.as_rule() { - Rule::condition_eq => Ok(Boolean::Equal { left, right }), - Rule::condition_diff => Ok(Boolean::Different { left, right }), - _ => Err(SemanticError::new("Invalid condition type").into()), - } - } -} - /// Parser for equality assertions. pub struct AssertEqParser; diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 10a648ca..86e56670 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -278,7 +278,7 @@ fn test_mini_program_2() { for j in i..10 { for k in j..10 { sum, prod = compute_sum_and_product(i, j, k); - if sum == 10 { + if (sum == 10) { print(i, j, k, prod); } } diff --git a/crates/lean_prover/tests/hash_chain.rs b/crates/lean_prover/tests/hash_chain.rs index a7634a3b..06afb886 100644 --- a/crates/lean_prover/tests/hash_chain.rs +++ b/crates/lean_prover/tests/hash_chain.rs @@ -101,7 +101,7 @@ fn benchmark_poseidon_chain() { display_logs: true, }); - println!("VM proof time: {:?}", vm_time); + println!("VM proof time: {vm_time:?}"); println!("Raw Poseidon proof time: {:?}", raw_proof.prover_time); println!( diff --git a/crates/packed_pcs/src/lib.rs b/crates/packed_pcs/src/lib.rs index b09eed24..5e54b182 100644 --- a/crates/packed_pcs/src/lib.rs +++ b/crates/packed_pcs/src/lib.rs @@ -99,10 +99,7 @@ fn split_in_chunks( if let Some(log_public) = dims.log_public_data_size { assert!( log_public >= log_smallest_decomposition_chunk, - "poly {}: {} < {}", - poly_index, - log_public, - log_smallest_decomposition_chunk + "poly {poly_index}: {log_public} < {log_smallest_decomposition_chunk}" ); res.push(Chunk { original_poly_index: poly_index, diff --git a/crates/rec_aggregation/src/xmss_aggregate.rs b/crates/rec_aggregation/src/xmss_aggregate.rs index 7e312f73..8dbb51a6 100644 --- a/crates/rec_aggregation/src/xmss_aggregate.rs +++ b/crates/rec_aggregation/src/xmss_aggregate.rs @@ -43,7 +43,7 @@ pub fn run_xmss_benchmark(n_xmss: usize) -> XmssBenchStats { bitield = public_input_start + (2 + N_PUBLIC_KEYS) * 8; signatures_start = private_input_start / 8; for i in 0..N_PUBLIC_KEYS { - if bitield[i] == 1 { + if !!assume_bool(bitield[i]) { xmss_public_key = all_public_keys + i; sig_index = counter_hint();