Skip to content

Commit 02a6f04

Browse files
authored
simplify conditions (#123)
Co-authored-by: Tom Wambsgans <[email protected]>
1 parent 88f3464 commit 02a6f04

File tree

9 files changed

+85
-213
lines changed

9 files changed

+85
-213
lines changed

crates/lean_compiler/src/a_simplify_lang.rs

Lines changed: 18 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ use crate::{
22
Counter, F,
33
ir::HighLevelOperation,
44
lang::{
5-
AssignmentTarget, AssumeBoolean, Condition, ConstExpression, ConstMallocLabel, ConstantValue, Context,
6-
Expression, Function, Line, Program, Scope, SimpleExpr, Var,
5+
AssignmentTarget, Condition, ConstExpression, ConstMallocLabel, Context, Expression, Function, Line, MathExpr,
6+
Program, Scope, SimpleExpr, Var,
77
},
88
parser::ConstArrayValue,
99
};
@@ -529,7 +529,7 @@ fn check_boolean_scoping(boolean: &BooleanExpr<Expression>, ctx: &Context) {
529529

530530
fn check_condition_scoping(condition: &Condition, ctx: &Context) {
531531
match condition {
532-
Condition::Expression(expr, _) => {
532+
Condition::AssumeBoolean(expr) => {
533533
check_expr_scoping(expr, ctx);
534534
}
535535
Condition::Comparison(boolean) => {
@@ -723,11 +723,10 @@ fn simplify_lines(
723723
if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) =
724724
(&left, &right)
725725
{
726-
let result = ConstExpression::Binary {
727-
left: Box::new(left_cst.clone()),
728-
operation: *operation,
729-
right: Box::new(right_cst.clone()),
730-
}
726+
let result = ConstExpression::MathExpr(
727+
MathExpr::Binary(*operation),
728+
vec![left_cst.clone(), right_cst.clone()],
729+
)
731730
.try_naive_simplification();
732731
res.push(SimpleLine::equality(var.clone(), SimpleExpr::Constant(result)));
733732
} else {
@@ -855,28 +854,8 @@ fn simplify_lines(
855854
});
856855
(diff_var.into(), then_branch, else_branch)
857856
}
858-
Condition::Expression(condition, assume_boolean) => {
857+
Condition::AssumeBoolean(condition) => {
859858
let condition_simplified = simplify_expr(ctx, state, const_malloc, condition, &mut res);
860-
861-
match assume_boolean {
862-
AssumeBoolean::AssumeBoolean => {}
863-
AssumeBoolean::DoNotAssumeBoolean => {
864-
// Check condition_simplified is boolean
865-
let one_minus_condition_var = state.counters.aux_var();
866-
res.push(SimpleLine::Assignment {
867-
var: one_minus_condition_var.clone().into(),
868-
operation: HighLevelOperation::Sub,
869-
arg0: SimpleExpr::Constant(ConstExpression::Value(ConstantValue::Scalar(1))),
870-
arg1: condition_simplified.clone(),
871-
});
872-
res.push(SimpleLine::AssertZero {
873-
operation: HighLevelOperation::Mul,
874-
arg0: condition_simplified.clone(),
875-
arg1: one_minus_condition_var.into(),
876-
});
877-
}
878-
}
879-
880859
(condition_simplified, then_branch, else_branch)
881860
}
882861
};
@@ -1194,11 +1173,10 @@ fn simplify_expr(
11941173
let right_var = simplify_expr(ctx, state, const_malloc, right, lines);
11951174

11961175
if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) = (&left_var, &right_var) {
1197-
return SimpleExpr::Constant(ConstExpression::Binary {
1198-
left: Box::new(left_cst.clone()),
1199-
operation: *operation,
1200-
right: Box::new(right_cst.clone()),
1201-
});
1176+
return SimpleExpr::Constant(ConstExpression::MathExpr(
1177+
MathExpr::Binary(*operation),
1178+
vec![left_cst.clone(), right_cst.clone()],
1179+
));
12021180
}
12031181

12041182
let aux_var = state.counters.aux_var();
@@ -1275,7 +1253,7 @@ pub fn find_variable_usage(
12751253
on_new_expr(&comp.left, internal_vars, external_vars);
12761254
on_new_expr(&comp.right, internal_vars, external_vars);
12771255
}
1278-
Condition::Expression(expr, _assume_boolean) => {
1256+
Condition::AssumeBoolean(expr) => {
12791257
on_new_expr(expr, internal_vars, external_vars);
12801258
}
12811259
};
@@ -1444,7 +1422,7 @@ fn inline_lines(
14441422

14451423
let inline_condition = |condition: &mut Condition| match condition {
14461424
Condition::Comparison(comparison) => inline_comparison(comparison),
1447-
Condition::Expression(expr, _assume_boolean) => inline_expr(expr, args, inlining_count),
1425+
Condition::AssumeBoolean(expr) => inline_expr(expr, args, inlining_count),
14481426
};
14491427

14501428
let inline_internal_var = |var: &mut Var| {
@@ -1883,7 +1861,7 @@ fn replace_vars_for_unroll(
18831861
internal_vars,
18841862
);
18851863
}
1886-
Condition::Expression(expr, _assume_bool) => {
1864+
Condition::AssumeBoolean(expr) => {
18871865
replace_vars_for_unroll_in_expr(expr, iterator, unroll_index, iterator_value, internal_vars);
18881866
}
18891867
}
@@ -2142,9 +2120,9 @@ fn extract_inlined_calls_from_condition(
21422120
inlined_var_counter: &mut Counter,
21432121
) -> (Condition, Vec<Line>) {
21442122
match condition {
2145-
Condition::Expression(expr, assume_boolean) => {
2123+
Condition::AssumeBoolean(expr) => {
21462124
let (expr, expr_lines) = extract_inlined_calls_from_expr(expr, inlined_functions, inlined_var_counter);
2147-
(Condition::Expression(expr, *assume_boolean), expr_lines)
2125+
(Condition::AssumeBoolean(expr), expr_lines)
21482126
}
21492127
Condition::Comparison(boolean) => {
21502128
let (boolean, boolean_lines) =
@@ -2719,7 +2697,7 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap<Var, F>) {
27192697
replace_vars_by_const_in_expr(&mut cond.left, map);
27202698
replace_vars_by_const_in_expr(&mut cond.right, map);
27212699
}
2722-
Condition::Expression(expr, _assume_boolean) => {
2700+
Condition::AssumeBoolean(expr) => {
27232701
replace_vars_by_const_in_expr(expr, map);
27242702
}
27252703
}

crates/lean_compiler/src/b_compile_intermediate.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,10 @@ impl Compiler {
5151
VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset } => {
5252
for scope in self.stack_frame_layout.scopes.iter().rev() {
5353
if let Some(base) = scope.const_mallocs.get(malloc_label) {
54-
return ConstExpression::Binary {
55-
left: Box::new((*base).into()),
56-
operation: HighLevelOperation::Add,
57-
right: Box::new((*offset).clone()),
58-
};
54+
return ConstExpression::MathExpr(
55+
MathExpr::Binary(HighLevelOperation::Add),
56+
vec![(*base).into(), offset.clone()],
57+
);
5958
}
6059
}
6160
panic!("Const malloc {malloc_label} not in scope");
@@ -518,7 +517,7 @@ fn compile_lines(
518517
});
519518
}
520519
SimpleLine::ConstMalloc { var, size, label } => {
521-
let size = size.naive_eval().unwrap().to_usize(); // TODO not very good;
520+
let size = size.naive_eval().unwrap().to_usize();
522521
if !compiler.is_in_scope(var) {
523522
let current_scope_layout = compiler.stack_frame_layout.scopes.last_mut().unwrap();
524523
current_scope_layout

crates/lean_compiler/src/grammar.pest

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,8 @@ statement = {
2929
return_statement |
3030
break_statement |
3131
continue_statement |
32-
assert_eq_statement |
33-
assert_not_eq_statement |
34-
debug_assert_eq_statement |
35-
debug_assert_not_eq_statement |
36-
debug_assert_lt_statement |
32+
assert_statement |
33+
debug_assert_statement |
3734
assignment
3835
}
3936

@@ -51,10 +48,14 @@ assignment_target = { array_access_expr | identifier }
5148

5249
if_statement = { "if" ~ condition ~ "{" ~ statement* ~ "}" ~ else_if_clause* ~ else_clause? }
5350

54-
condition = { expression | assumed_bool_expr }
51+
condition = { assumed_bool_expr | comparison }
5552

5653
assumed_bool_expr = { "!!assume_bool" ~ "(" ~ expression ~ ")" }
5754

55+
// Comparisons (shared between conditions and assertions)
56+
comparison = { add_expr ~ comparison_op ~ add_expr }
57+
comparison_op = { "==" | "!=" | "<" }
58+
5859
else_if_clause = { "else" ~ "if" ~ condition ~ "{" ~ statement* ~ "}" }
5960

6061
else_clause = { "else" ~ "{" ~ statement* ~ "}" }
@@ -67,18 +68,12 @@ match_statement = { "match" ~ expression ~ "{" ~ match_arm* ~ "}" }
6768
match_arm = { pattern ~ "=>" ~ "{" ~ statement* ~ "}" }
6869
pattern = { constant_value }
6970

70-
assert_eq_statement = { "assert" ~ add_expr ~ "==" ~ add_expr ~ ";" }
71-
assert_not_eq_statement = { "assert" ~ add_expr ~ "!=" ~ add_expr ~ ";" }
72-
73-
debug_assert_eq_statement = { "debug_assert" ~ add_expr ~ "==" ~ add_expr ~ ";" }
74-
debug_assert_not_eq_statement = { "debug_assert" ~ add_expr ~ "!=" ~ add_expr ~ ";" }
75-
debug_assert_lt_statement = { "debug_assert" ~ add_expr ~ "<" ~ add_expr ~ ";" }
71+
assert_statement = { "assert" ~ comparison ~ ";" }
72+
debug_assert_statement = { "debug_assert" ~ comparison ~ ";" }
7673

7774
// Expressions
7875
tuple_expression = { expression ~ ("," ~ expression)* }
79-
expression = { neq_expr }
80-
neq_expr = { eq_expr ~ ("!=" ~ eq_expr)* }
81-
eq_expr = { add_expr ~ ("==" ~ add_expr)* }
76+
expression = { add_expr }
8277
add_expr = { sub_expr ~ ("+" ~ sub_expr)* }
8378
sub_expr = { mul_expr ~ ("-" ~ mul_expr)* }
8479
mul_expr = { mod_expr ~ ("*" ~ mod_expr)* }

crates/lean_compiler/src/ir/instruction.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,7 @@ impl IntermediateInstruction {
9494
arg_c,
9595
res: arg_a,
9696
},
97-
HighLevelOperation::Exp
98-
| HighLevelOperation::Mod
99-
| HighLevelOperation::Equal
100-
| HighLevelOperation::NotEqual => unreachable!(),
97+
HighLevelOperation::Exp | HighLevelOperation::Mod => unreachable!(),
10198
}
10299
}
103100

crates/lean_compiler/src/ir/operation.rs

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@ use multilinear_toolkit::prelude::*;
44
use std::fmt::{Display, Formatter};
55
use utils::ToUsize;
66

7-
/// High-level operations that can be performed in the IR.
8-
///
9-
/// These operations represent the semantic intent of computations
10-
/// and may be lowered to different VM operations depending on context.
117
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
128
pub enum HighLevelOperation {
139
/// Addition operation.
@@ -22,29 +18,11 @@ pub enum HighLevelOperation {
2218
Exp,
2319
/// Modulo operation (only for constant expressions).
2420
Mod,
25-
/// Equality comparison
26-
Equal,
27-
/// Non-equality comparison
28-
NotEqual,
2921
}
3022

3123
impl HighLevelOperation {
3224
pub fn eval(&self, a: F, b: F) -> F {
3325
match self {
34-
Self::Equal => {
35-
if a == b {
36-
F::ONE
37-
} else {
38-
F::ZERO
39-
}
40-
}
41-
Self::NotEqual => {
42-
if a != b {
43-
F::ONE
44-
} else {
45-
F::ZERO
46-
}
47-
}
4826
Self::Add => a + b,
4927
Self::Mul => a * b,
5028
Self::Sub => a - b,
@@ -58,8 +36,6 @@ impl HighLevelOperation {
5836
impl Display for HighLevelOperation {
5937
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
6038
match self {
61-
Self::Equal => write!(f, "=="),
62-
Self::NotEqual => write!(f, "!="),
6339
Self::Add => write!(f, "+"),
6440
Self::Mul => write!(f, "*"),
6541
Self::Sub => write!(f, "-"),

crates/lean_compiler/src/lang.rs

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,6 @@ pub enum ConstantValue {
115115
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
116116
pub enum ConstExpression {
117117
Value(ConstantValue),
118-
Binary {
119-
left: Box<Self>,
120-
operation: HighLevelOperation,
121-
right: Box<Self>,
122-
},
123118
MathExpr(MathExpr, Vec<Self>),
124119
}
125120

@@ -140,11 +135,7 @@ impl TryFrom<Expression> for ConstExpression {
140135
Expression::Binary { left, operation, right } => {
141136
let left_expr = Self::try_from(*left)?;
142137
let right_expr = Self::try_from(*right)?;
143-
Ok(Self::Binary {
144-
left: Box::new(left_expr),
145-
operation,
146-
right: Box::new(right_expr),
147-
})
138+
Ok(Self::MathExpr(MathExpr::Binary(operation), vec![left_expr, right_expr]))
148139
}
149140
Expression::MathExpr(math_expr, args) => {
150141
let mut const_args = Vec::new();
@@ -185,9 +176,6 @@ impl ConstExpression {
185176
{
186177
match self {
187178
Self::Value(value) => func(value),
188-
Self::Binary { left, operation, right } => {
189-
Some(operation.eval(left.eval_with(func)?, right.eval_with(func)?))
190-
}
191179
Self::MathExpr(math_expr, args) => {
192180
let mut eval_args = Vec::new();
193181
for arg in args {
@@ -220,25 +208,16 @@ impl From<ConstantValue> for ConstExpression {
220208
}
221209
}
222210

223-
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
224-
pub enum AssumeBoolean {
225-
AssumeBoolean,
226-
DoNotAssumeBoolean,
227-
}
228-
229211
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
230212
pub enum Condition {
231-
Expression(Expression, AssumeBoolean),
213+
AssumeBoolean(Expression),
232214
Comparison(BooleanExpr<Expression>),
233215
}
234216

235217
impl Display for Condition {
236218
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
237219
match self {
238-
Self::Expression(expr, AssumeBoolean::AssumeBoolean) => {
239-
write!(f, "!assume_bool({expr})")
240-
}
241-
Self::Expression(expr, AssumeBoolean::DoNotAssumeBoolean) => write!(f, "{expr}"),
220+
Self::AssumeBoolean(expr) => write!(f, "{expr}"),
242221
Self::Comparison(cmp) => write!(f, "{cmp}"),
243222
}
244223
}
@@ -270,6 +249,7 @@ pub enum Expression {
270249
/// For arbitrary compile-time computations
271250
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
272251
pub enum MathExpr {
252+
Binary(HighLevelOperation),
273253
Log2Ceil,
274254
NextMultipleOf,
275255
SaturatingSub,
@@ -278,6 +258,7 @@ pub enum MathExpr {
278258
impl Display for MathExpr {
279259
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
280260
match self {
261+
Self::Binary(op) => write!(f, "{op}"),
281262
Self::Log2Ceil => write!(f, "log2_ceil"),
282263
Self::NextMultipleOf => write!(f, "next_multiple_of"),
283264
Self::SaturatingSub => write!(f, "saturating_sub"),
@@ -288,13 +269,18 @@ impl Display for MathExpr {
288269
impl MathExpr {
289270
pub fn num_args(&self) -> usize {
290271
match self {
272+
Self::Binary(_) => 2,
291273
Self::Log2Ceil => 1,
292274
Self::NextMultipleOf => 2,
293275
Self::SaturatingSub => 2,
294276
}
295277
}
296278
pub fn eval(&self, args: &[F]) -> F {
297279
match self {
280+
Self::Binary(op) => {
281+
assert_eq!(args.len(), 2);
282+
op.eval(args[0], args[1])
283+
}
298284
Self::Log2Ceil => {
299285
assert_eq!(args.len(), 1);
300286
let value = args[0];
@@ -722,9 +708,6 @@ impl Display for ConstExpression {
722708
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
723709
match self {
724710
Self::Value(value) => write!(f, "{value}"),
725-
Self::Binary { left, operation, right } => {
726-
write!(f, "({left} {operation} {right})")
727-
}
728711
Self::MathExpr(math_expr, args) => {
729712
let args_str = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>().join(", ");
730713
write!(f, "{math_expr}({args_str})")

0 commit comments

Comments
 (0)