Skip to content

Commit cad491f

Browse files
committed
simplify conditions
1 parent 88f3464 commit cad491f

File tree

9 files changed

+73
-172
lines changed

9 files changed

+73
-172
lines changed

crates/lean_compiler/src/a_simplify_lang.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::{
33
ir::HighLevelOperation,
44
lang::{
55
AssignmentTarget, AssumeBoolean, Condition, ConstExpression, ConstMallocLabel, ConstantValue, Context,
6-
Expression, Function, Line, Program, Scope, SimpleExpr, Var,
6+
Expression, Function, Line, MathExpr, Program, Scope, SimpleExpr, Var,
77
},
88
parser::ConstArrayValue,
99
};
@@ -723,11 +723,10 @@ fn simplify_lines(
723723
if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) =
724724
(&left, &right)
725725
{
726-
let result = ConstExpression::Binary {
727-
left: Box::new(left_cst.clone()),
728-
operation: *operation,
729-
right: Box::new(right_cst.clone()),
730-
}
726+
let result = ConstExpression::MathExpr(
727+
MathExpr::Binary(*operation),
728+
vec![left_cst.clone(), right_cst.clone()],
729+
)
731730
.try_naive_simplification();
732731
res.push(SimpleLine::equality(var.clone(), SimpleExpr::Constant(result)));
733732
} else {
@@ -1194,11 +1193,10 @@ fn simplify_expr(
11941193
let right_var = simplify_expr(ctx, state, const_malloc, right, lines);
11951194

11961195
if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) = (&left_var, &right_var) {
1197-
return SimpleExpr::Constant(ConstExpression::Binary {
1198-
left: Box::new(left_cst.clone()),
1199-
operation: *operation,
1200-
right: Box::new(right_cst.clone()),
1201-
});
1196+
return SimpleExpr::Constant(ConstExpression::MathExpr(
1197+
MathExpr::Binary(*operation),
1198+
vec![left_cst.clone(), right_cst.clone()],
1199+
));
12021200
}
12031201

12041202
let aux_var = state.counters.aux_var();

crates/lean_compiler/src/b_compile_intermediate.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,10 @@ impl Compiler {
5151
VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset } => {
5252
for scope in self.stack_frame_layout.scopes.iter().rev() {
5353
if let Some(base) = scope.const_mallocs.get(malloc_label) {
54-
return ConstExpression::Binary {
55-
left: Box::new((*base).into()),
56-
operation: HighLevelOperation::Add,
57-
right: Box::new((*offset).clone()),
58-
};
54+
return ConstExpression::MathExpr(
55+
MathExpr::Binary(HighLevelOperation::Add),
56+
vec![(*base).into(), offset.clone()],
57+
);
5958
}
6059
}
6160
panic!("Const malloc {malloc_label} not in scope");
@@ -518,7 +517,7 @@ fn compile_lines(
518517
});
519518
}
520519
SimpleLine::ConstMalloc { var, size, label } => {
521-
let size = size.naive_eval().unwrap().to_usize(); // TODO not very good;
520+
let size = size.naive_eval().unwrap().to_usize();
522521
if !compiler.is_in_scope(var) {
523522
let current_scope_layout = compiler.stack_frame_layout.scopes.last_mut().unwrap();
524523
current_scope_layout

crates/lean_compiler/src/grammar.pest

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,8 @@ statement = {
2929
return_statement |
3030
break_statement |
3131
continue_statement |
32-
assert_eq_statement |
33-
assert_not_eq_statement |
34-
debug_assert_eq_statement |
35-
debug_assert_not_eq_statement |
36-
debug_assert_lt_statement |
32+
assert_statement |
33+
debug_assert_statement |
3734
assignment
3835
}
3936

@@ -51,10 +48,14 @@ assignment_target = { array_access_expr | identifier }
5148

5249
if_statement = { "if" ~ condition ~ "{" ~ statement* ~ "}" ~ else_if_clause* ~ else_clause? }
5350

54-
condition = { expression | assumed_bool_expr }
51+
condition = { assumed_bool_expr | comparison }
5552

5653
assumed_bool_expr = { "!!assume_bool" ~ "(" ~ expression ~ ")" }
5754

55+
// Comparisons (shared between conditions and assertions)
56+
comparison = { add_expr ~ comparison_op ~ add_expr }
57+
comparison_op = { "==" | "!=" | "<" }
58+
5859
else_if_clause = { "else" ~ "if" ~ condition ~ "{" ~ statement* ~ "}" }
5960

6061
else_clause = { "else" ~ "{" ~ statement* ~ "}" }
@@ -67,18 +68,12 @@ match_statement = { "match" ~ expression ~ "{" ~ match_arm* ~ "}" }
6768
match_arm = { pattern ~ "=>" ~ "{" ~ statement* ~ "}" }
6869
pattern = { constant_value }
6970

70-
assert_eq_statement = { "assert" ~ add_expr ~ "==" ~ add_expr ~ ";" }
71-
assert_not_eq_statement = { "assert" ~ add_expr ~ "!=" ~ add_expr ~ ";" }
72-
73-
debug_assert_eq_statement = { "debug_assert" ~ add_expr ~ "==" ~ add_expr ~ ";" }
74-
debug_assert_not_eq_statement = { "debug_assert" ~ add_expr ~ "!=" ~ add_expr ~ ";" }
75-
debug_assert_lt_statement = { "debug_assert" ~ add_expr ~ "<" ~ add_expr ~ ";" }
71+
assert_statement = { "assert" ~ comparison ~ ";" }
72+
debug_assert_statement = { "debug_assert" ~ comparison ~ ";" }
7673

7774
// Expressions
7875
tuple_expression = { expression ~ ("," ~ expression)* }
79-
expression = { neq_expr }
80-
neq_expr = { eq_expr ~ ("!=" ~ eq_expr)* }
81-
eq_expr = { add_expr ~ ("==" ~ add_expr)* }
76+
expression = { add_expr }
8277
add_expr = { sub_expr ~ ("+" ~ sub_expr)* }
8378
sub_expr = { mul_expr ~ ("-" ~ mul_expr)* }
8479
mul_expr = { mod_expr ~ ("*" ~ mod_expr)* }

crates/lean_compiler/src/ir/instruction.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,7 @@ impl IntermediateInstruction {
9494
arg_c,
9595
res: arg_a,
9696
},
97-
HighLevelOperation::Exp
98-
| HighLevelOperation::Mod
99-
| HighLevelOperation::Equal
100-
| HighLevelOperation::NotEqual => unreachable!(),
97+
HighLevelOperation::Exp | HighLevelOperation::Mod => unreachable!(),
10198
}
10299
}
103100

crates/lean_compiler/src/ir/operation.rs

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@ use multilinear_toolkit::prelude::*;
44
use std::fmt::{Display, Formatter};
55
use utils::ToUsize;
66

7-
/// High-level operations that can be performed in the IR.
8-
///
9-
/// These operations represent the semantic intent of computations
10-
/// and may be lowered to different VM operations depending on context.
117
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
128
pub enum HighLevelOperation {
139
/// Addition operation.
@@ -22,29 +18,11 @@ pub enum HighLevelOperation {
2218
Exp,
2319
/// Modulo operation (only for constant expressions).
2420
Mod,
25-
/// Equality comparison
26-
Equal,
27-
/// Non-equality comparison
28-
NotEqual,
2921
}
3022

3123
impl HighLevelOperation {
3224
pub fn eval(&self, a: F, b: F) -> F {
3325
match self {
34-
Self::Equal => {
35-
if a == b {
36-
F::ONE
37-
} else {
38-
F::ZERO
39-
}
40-
}
41-
Self::NotEqual => {
42-
if a != b {
43-
F::ONE
44-
} else {
45-
F::ZERO
46-
}
47-
}
4826
Self::Add => a + b,
4927
Self::Mul => a * b,
5028
Self::Sub => a - b,
@@ -58,8 +36,6 @@ impl HighLevelOperation {
5836
impl Display for HighLevelOperation {
5937
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
6038
match self {
61-
Self::Equal => write!(f, "=="),
62-
Self::NotEqual => write!(f, "!="),
6339
Self::Add => write!(f, "+"),
6440
Self::Mul => write!(f, "*"),
6541
Self::Sub => write!(f, "-"),

crates/lean_compiler/src/lang.rs

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,6 @@ pub enum ConstantValue {
115115
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
116116
pub enum ConstExpression {
117117
Value(ConstantValue),
118-
Binary {
119-
left: Box<Self>,
120-
operation: HighLevelOperation,
121-
right: Box<Self>,
122-
},
123118
MathExpr(MathExpr, Vec<Self>),
124119
}
125120

@@ -140,11 +135,7 @@ impl TryFrom<Expression> for ConstExpression {
140135
Expression::Binary { left, operation, right } => {
141136
let left_expr = Self::try_from(*left)?;
142137
let right_expr = Self::try_from(*right)?;
143-
Ok(Self::Binary {
144-
left: Box::new(left_expr),
145-
operation,
146-
right: Box::new(right_expr),
147-
})
138+
Ok(Self::MathExpr(MathExpr::Binary(operation), vec![left_expr, right_expr]))
148139
}
149140
Expression::MathExpr(math_expr, args) => {
150141
let mut const_args = Vec::new();
@@ -185,9 +176,6 @@ impl ConstExpression {
185176
{
186177
match self {
187178
Self::Value(value) => func(value),
188-
Self::Binary { left, operation, right } => {
189-
Some(operation.eval(left.eval_with(func)?, right.eval_with(func)?))
190-
}
191179
Self::MathExpr(math_expr, args) => {
192180
let mut eval_args = Vec::new();
193181
for arg in args {
@@ -270,6 +258,7 @@ pub enum Expression {
270258
/// For arbitrary compile-time computations
271259
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
272260
pub enum MathExpr {
261+
Binary(HighLevelOperation),
273262
Log2Ceil,
274263
NextMultipleOf,
275264
SaturatingSub,
@@ -278,6 +267,7 @@ pub enum MathExpr {
278267
impl Display for MathExpr {
279268
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
280269
match self {
270+
Self::Binary(op) => write!(f, "{op}"),
281271
Self::Log2Ceil => write!(f, "log2_ceil"),
282272
Self::NextMultipleOf => write!(f, "next_multiple_of"),
283273
Self::SaturatingSub => write!(f, "saturating_sub"),
@@ -288,13 +278,18 @@ impl Display for MathExpr {
288278
impl MathExpr {
289279
pub fn num_args(&self) -> usize {
290280
match self {
281+
Self::Binary(_) => 2,
291282
Self::Log2Ceil => 1,
292283
Self::NextMultipleOf => 2,
293284
Self::SaturatingSub => 2,
294285
}
295286
}
296287
pub fn eval(&self, args: &[F]) -> F {
297288
match self {
289+
Self::Binary(op) => {
290+
assert_eq!(args.len(), 2);
291+
op.eval(args[0], args[1])
292+
}
298293
Self::Log2Ceil => {
299294
assert_eq!(args.len(), 1);
300295
let value = args[0];
@@ -722,9 +717,6 @@ impl Display for ConstExpression {
722717
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
723718
match self {
724719
Self::Value(value) => write!(f, "{value}"),
725-
Self::Binary { left, operation, right } => {
726-
write!(f, "({left} {operation} {right})")
727-
}
728720
Self::MathExpr(math_expr, args) => {
729721
let args_str = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>().join(", ");
730722
write!(f, "{math_expr}({args_str})")

crates/lean_compiler/src/parser/parsers/expression.rs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@ use crate::{
55
ir::HighLevelOperation,
66
lang::{ConstExpression, ConstantValue, Expression, SimpleExpr},
77
parser::{
8-
error::{ParseError, ParseResult, SemanticError},
8+
error::{ParseResult, SemanticError},
99
grammar::{ParsePair, Rule},
1010
},
1111
};
1212

13-
/// Parser for all expression types.
1413
pub struct ExpressionParser;
1514

1615
impl Parse<Expression> for ExpressionParser {
@@ -20,23 +19,18 @@ impl Parse<Expression> for ExpressionParser {
2019
let inner = next_inner_pair(&mut pair.into_inner(), "expression body")?;
2120
Self.parse(inner, ctx)
2221
}
23-
Rule::neq_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::NotEqual),
24-
Rule::eq_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Equal),
2522
Rule::add_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Add),
2623
Rule::sub_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Sub),
2724
Rule::mul_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Mul),
2825
Rule::mod_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Mod),
2926
Rule::div_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Div),
3027
Rule::exp_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Exp),
3128
Rule::primary => PrimaryExpressionParser.parse(pair, ctx),
32-
other_rule => Err(ParseError::SemanticError(SemanticError::new(format!(
33-
"ExpressionParser: Unexpected rule {other_rule:?}"
34-
)))),
29+
other_rule => Err(SemanticError::new(format!("ExpressionParser: Unexpected rule {other_rule:?}")).into()),
3530
}
3631
}
3732
}
3833

39-
/// Parser for binary arithmetic operations.
4034
pub struct BinaryExpressionParser;
4135

4236
impl BinaryExpressionParser {

0 commit comments

Comments
 (0)