diff --git a/crates/compiler/src/a_simplify_lang.rs b/crates/compiler/src/a_simplify_lang.rs index 4f119972..6e989196 100644 --- a/crates/compiler/src/a_simplify_lang.rs +++ b/crates/compiler/src/a_simplify_lang.rs @@ -1132,10 +1132,7 @@ fn handle_const_arguments_helper( for (arg_expr, (arg_var, is_constant)) in args.iter().zip(&func.arguments) { if *is_constant { let const_eval = arg_expr.naive_eval().unwrap_or_else(|| { - panic!( - "Failed to evaluate constant argument: {}", - arg_expr.to_string() - ) + panic!("Failed to evaluate constant argument: {arg_expr}") }); const_evals.push((arg_var.clone(), const_eval)); } @@ -1332,11 +1329,7 @@ impl ToString for VarOrConstMallocAccess { malloc_label, offset, } => { - format!( - "ConstMallocAccess({}, {})", - malloc_label, - offset.to_string() - ) + format!("ConstMallocAccess({malloc_label}, {offset})") } } } @@ -1352,28 +1345,17 @@ impl SimpleLine { arg0, arg1, } => { - format!( - "{} = {} {} {}", - var.to_string(), - arg0.to_string(), - operation.to_string(), - arg1.to_string() - ) + format!("{} = {} {} {}", var.to_string(), arg0, operation, arg1) } Self::DecomposeBits { var: result, to_decompose, label: _, } => { - format!("{} = decompose_bits({})", result, to_decompose.to_string()) + format!("{result} = decompose_bits({to_decompose})") } Self::RawAccess { res, index, shift } => { - format!( - "{} = memory[{} + {}]", - res.to_string(), - index, - shift.to_string() - ) + format!("{res} = memory[{index} + {shift}]") } Self::IfNotZero { condition, @@ -1393,20 +1375,10 @@ impl SimpleLine { .join("\n"); if else_branch.is_empty() { - format!( - "if {} != 0 {{\n{}\n{}}}", - condition.to_string(), - then_str, - spaces - ) + format!("if {condition} != 0 {{\n{then_str}\n{spaces}}}") } else { format!( - "if {} != 0 {{\n{}\n{}}} else {{\n{}\n{}}}", - condition.to_string(), - then_str, - spaces, - else_str, - spaces + "if {condition} != 0 {{\n{then_str}\n{spaces}}} else {{\n{else_str}\n{spaces}}}" ) } } @@ -1476,14 +1448,14 @@ impl SimpleLine { } else { "malloc" }; - format!("{} = {}({})", var, alloc_type, size.to_string()) + format!("{var} = {alloc_type}({size})") } Self::ConstMalloc { var, size, label: _, } => { - format!("{} = malloc({})", var, size.to_string()) + format!("{var} = malloc({size})") } Self::Panic => "panic".to_string(), }; diff --git a/crates/compiler/src/intermediate_bytecode.rs b/crates/compiler/src/intermediate_bytecode.rs index 8d4de9ac..32da4694 100644 --- a/crates/compiler/src/intermediate_bytecode.rs +++ b/crates/compiler/src/intermediate_bytecode.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::{collections::BTreeMap, fmt}; use p3_field::{PrimeCharacteristicRing, PrimeField64}; use vm::{Label, Operation}; @@ -73,6 +73,18 @@ impl HighLevelOperation { } } +impl fmt::Display for HighLevelOperation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Add => write!(f, "+"), + Self::Mul => write!(f, "*"), + Self::Sub => write!(f, "-"), + Self::Div => write!(f, "/"), + Self::Exp => write!(f, "**"), + } + } +} + #[derive(Debug, Clone)] pub(crate) enum IntermediateInstruction { Computation { @@ -191,7 +203,7 @@ impl ToString for IntermediateValue { Self::Constant(value) => value.to_string(), Self::Fp => "fp".to_string(), Self::MemoryAfterFp { offset } => { - format!("m[fp + {}]", offset.to_string()) + format!("m[fp + {offset}]") } } } @@ -200,7 +212,7 @@ impl ToString for IntermediateValue { impl ToString for IntermediaryMemOrFpOrConstant { fn to_string(&self) -> String { match self { - Self::MemoryAfterFp { offset } => format!("m[fp + {}]", offset.to_string()), + Self::MemoryAfterFp { offset } => format!("m[fp + {offset}]"), Self::Fp => "fp".to_string(), Self::Constant(c) => c.to_string(), } @@ -214,12 +226,7 @@ impl ToString for IntermediateInstruction { shift_0, shift_1, res, - } => format!( - "{} = m[m[fp + {}] + {}]", - res.to_string(), - shift_0.to_string(), - shift_1.to_string() - ), + } => format!("{} = m[m[fp + {}] + {}]", res.to_string(), shift_0, shift_1), Self::DotProduct { arg0, arg1, @@ -230,7 +237,7 @@ impl ToString for IntermediateInstruction { arg0.to_string(), arg1.to_string(), res.to_string(), - size.to_string() + size ), Self::MultilinearEval { coeffs, @@ -242,7 +249,7 @@ impl ToString for IntermediateInstruction { coeffs.to_string(), point.to_string(), res.to_string(), - n_vars.to_string() + n_vars ), Self::DecomposeBits { res_offset, @@ -318,7 +325,7 @@ impl ToString for IntermediateInstruction { vectorized, } => format!( "m[fp + {}] = {}({})", - offset.to_string(), + offset, if *vectorized { "malloc_vec" } else { "malloc" }, size.to_string(), ), @@ -335,18 +342,6 @@ impl ToString for IntermediateInstruction { } } -impl ToString for HighLevelOperation { - fn to_string(&self) -> String { - match self { - Self::Add => "+".to_string(), - Self::Mul => "*".to_string(), - Self::Sub => "-".to_string(), - Self::Div => "/".to_string(), - Self::Exp => "**".to_string(), - } - } -} - impl ToString for IntermediateBytecode { fn to_string(&self) -> String { let mut res = String::new(); diff --git a/crates/compiler/src/lang.rs b/crates/compiler/src/lang.rs deleted file mode 100644 index 7175766e..00000000 --- a/crates/compiler/src/lang.rs +++ /dev/null @@ -1,609 +0,0 @@ -use std::collections::BTreeMap; - -use p3_field::PrimeCharacteristicRing; -use utils::ToUsize; -use vm::Label; - -use crate::{F, intermediate_bytecode::HighLevelOperation, precompiles::Precompile}; - -#[derive(Debug, Clone)] -pub(crate) struct Program { - pub functions: BTreeMap, -} - -#[derive(Debug, Clone)] -pub(crate) struct Function { - pub name: String, - pub arguments: Vec<(Var, bool)>, // (name, is_const) - pub n_returned_vars: usize, - pub body: Vec, -} - -impl Function { - pub(crate) fn has_const_arguments(&self) -> bool { - self.arguments.iter().any(|(_, is_const)| *is_const) - } -} - -pub(crate) type Var = String; -pub(crate) type ConstMallocLabel = usize; - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub(crate) enum SimpleExpr { - Var(Var), - Constant(ConstExpression), - ConstMallocAccess { - malloc_label: ConstMallocLabel, - offset: ConstExpression, - }, -} - -impl SimpleExpr { - pub(crate) fn zero() -> Self { - Self::scalar(0) - } - - pub(crate) fn one() -> Self { - Self::scalar(1) - } - - pub(crate) fn scalar(scalar: usize) -> Self { - Self::Constant(ConstantValue::Scalar(scalar).into()) - } - - pub(crate) const fn is_constant(&self) -> bool { - matches!(self, Self::Constant(_)) - } - - pub(crate) fn simplify_if_const(&self) -> Self { - if let Self::Constant(constant) = self { - return constant.try_naive_simplification().into(); - } - self.clone() - } -} - -impl From for SimpleExpr { - fn from(constant: ConstantValue) -> Self { - Self::Constant(constant.into()) - } -} - -impl From for SimpleExpr { - fn from(constant: ConstExpression) -> Self { - Self::Constant(constant) - } -} - -impl From for SimpleExpr { - fn from(var: Var) -> Self { - Self::Var(var) - } -} - -impl SimpleExpr { - pub(crate) fn as_constant(&self) -> Option { - match self { - Self::Var(_) | Self::ConstMallocAccess { .. } => None, - Self::Constant(constant) => Some(constant.clone()), - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub(crate) enum Boolean { - Equal { left: Expression, right: Expression }, - Different { left: Expression, right: Expression }, -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub(crate) enum ConstantValue { - Scalar(usize), - PublicInputStart, - PointerToZeroVector, // In the memory of chunks of 8 field elements - FunctionSize { function_name: Label }, - Label(Label), -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub(crate) enum ConstExpression { - Value(ConstantValue), - Binary { - left: Box, - operation: HighLevelOperation, - right: Box, - }, -} - -impl From for ConstExpression { - fn from(value: usize) -> Self { - Self::Value(ConstantValue::Scalar(value)) - } -} - -impl TryFrom for ConstExpression { - type Error = (); - - fn try_from(value: Expression) -> Result { - match value { - Expression::Value(SimpleExpr::Constant(const_expr)) => Ok(const_expr), - Expression::Value(_) | Expression::ArrayAccess { .. } => Err(()), - Expression::Binary { - left, - operation, - right, - } => { - let left_expr = Self::try_from(*left)?; - let right_expr = Self::try_from(*right)?; - Ok(Self::Binary { - left: Box::new(left_expr), - operation, - right: Box::new(right_expr), - }) - } - } - } -} - -impl ConstExpression { - pub(crate) const fn zero() -> Self { - Self::scalar(0) - } - - pub(crate) const fn one() -> Self { - Self::scalar(1) - } - - pub(crate) const fn label(label: Label) -> Self { - Self::Value(ConstantValue::Label(label)) - } - - pub(crate) const fn scalar(scalar: usize) -> Self { - Self::Value(ConstantValue::Scalar(scalar)) - } - - pub(crate) const fn function_size(function_name: Label) -> Self { - Self::Value(ConstantValue::FunctionSize { function_name }) - } - - pub(crate) fn eval_with(&self, func: &EvalFn) -> Option - where - EvalFn: Fn(&ConstantValue) -> Option, - { - match self { - Self::Value(value) => func(value), - Self::Binary { - left, - operation, - right, - } => Some(operation.eval(left.eval_with(func)?, right.eval_with(func)?)), - } - } - - pub(crate) fn naive_eval(&self) -> Option { - self.eval_with(&|value| match value { - ConstantValue::Scalar(scalar) => Some(F::from_usize(*scalar)), - _ => None, - }) - } - - pub(crate) fn try_naive_simplification(&self) -> Self { - self.naive_eval() - .map_or_else(|| self.clone(), |value| Self::scalar(value.to_usize())) - } -} - -impl From for ConstExpression { - fn from(value: ConstantValue) -> Self { - Self::Value(value) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub(crate) enum Expression { - Value(SimpleExpr), - ArrayAccess { - array: Var, - index: Box, - }, - Binary { - left: Box, - operation: HighLevelOperation, - right: Box, - }, -} - -impl From for Expression { - fn from(value: SimpleExpr) -> Self { - Self::Value(value) - } -} - -impl From for Expression { - fn from(var: Var) -> Self { - Self::Value(var.into()) - } -} - -impl Expression { - pub(crate) fn naive_eval(&self) -> Option { - self.eval_with( - &|value: &SimpleExpr| value.as_constant()?.naive_eval(), - &|_, _| None, - ) - } - - pub(crate) fn eval_with( - &self, - value_fn: &ValueFn, - array_fn: &ArrayFn, - ) -> Option - where - ValueFn: Fn(&SimpleExpr) -> Option, - ArrayFn: Fn(&Var, F) -> Option, - { - match self { - Self::Value(value) => value_fn(value), - Self::ArrayAccess { array, index } => { - array_fn(array, index.eval_with(value_fn, array_fn)?) - } - Self::Binary { - left, - operation, - right, - } => Some(operation.eval( - left.eval_with(value_fn, array_fn)?, - right.eval_with(value_fn, array_fn)?, - )), - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub(crate) enum Line { - Assignment { - var: Var, - value: Expression, - }, - ArrayAssign { - // array[index] = value - array: Var, - index: Expression, - value: Expression, - }, - Assert(Boolean), - IfCondition { - condition: Boolean, - then_branch: Vec, - else_branch: Vec, - }, - ForLoop { - iterator: Var, - start: Expression, - end: Expression, - body: Vec, - rev: bool, - unroll: bool, - }, - FunctionCall { - function_name: String, - args: Vec, - return_data: Vec, - }, - FunctionRet { - return_data: Vec, - }, - Precompile { - precompile: Precompile, - args: Vec, - res: Vec, - }, - Break, - Panic, - // Hints: - Print { - line_info: String, - content: Vec, - }, - MAlloc { - var: Var, - size: Expression, - vectorized: bool, - }, - DecomposeBits { - var: Var, // a pointer to 31 field elements, containing the bits of "to_decompose" - to_decompose: Expression, - }, -} - -impl ToString for Expression { - fn to_string(&self) -> String { - match self { - Self::Value(val) => val.to_string(), - Self::ArrayAccess { array, index } => { - format!("{}[{}]", array, index.to_string()) - } - Self::Binary { - left, - operation, - right, - } => { - format!( - "({} {} {})", - left.to_string(), - operation.to_string(), - right.to_string() - ) - } - } - } -} - -impl Line { - fn to_string_with_indent(&self, indent: usize) -> String { - let spaces = " ".repeat(indent); - let line_str = match self { - Self::Assignment { var, value } => { - format!("{} = {}", var, value.to_string()) - } - Self::ArrayAssign { - array, - index, - value, - } => { - format!("{}[{}] = {}", array, index.to_string(), value.to_string()) - } - Self::Assert(condition) => format!("assert {}", condition.to_string()), - Self::IfCondition { - condition, - then_branch, - else_branch, - } => { - let then_str = then_branch - .iter() - .map(|line| line.to_string_with_indent(indent + 1)) - .collect::>() - .join("\n"); - - let else_str = else_branch - .iter() - .map(|line| line.to_string_with_indent(indent + 1)) - .collect::>() - .join("\n"); - - if else_branch.is_empty() { - format!( - "if {} {{\n{}\n{}}}", - condition.to_string(), - then_str, - spaces - ) - } else { - format!( - "if {} {{\n{}\n{}}} else {{\n{}\n{}}}", - condition.to_string(), - then_str, - spaces, - else_str, - spaces - ) - } - } - Self::ForLoop { - iterator, - start, - end, - body, - rev, - unroll, - } => { - let body_str = body - .iter() - .map(|line| line.to_string_with_indent(indent + 1)) - .collect::>() - .join("\n"); - format!( - "for {} in {}{}..{} {}{{\n{}\n{}}}", - iterator, - start.to_string(), - if *rev { "rev " } else { "" }, - end.to_string(), - if *unroll { "unroll " } else { "" }, - body_str, - spaces - ) - } - Self::FunctionCall { - function_name, - args, - return_data, - } => { - let args_str = args - .iter() - .map(std::string::ToString::to_string) - .collect::>() - .join(", "); - let return_data_str = return_data - .iter() - .map(std::string::ToString::to_string) - .collect::>() - .join(", "); - - if return_data.is_empty() { - format!("{function_name}({args_str})") - } else { - format!("{return_data_str} = {function_name}({args_str})") - } - } - Self::FunctionRet { return_data } => { - let return_data_str = return_data - .iter() - .map(std::string::ToString::to_string) - .collect::>() - .join(", "); - format!("return {return_data_str}") - } - Self::Precompile { - precompile, - args, - res: return_data, - } => { - format!( - "{} = {}({})", - return_data - .iter() - .map(std::string::ToString::to_string) - .collect::>() - .join(", "), - precompile.name.to_string(), - args.iter() - .map(std::string::ToString::to_string) - .collect::>() - .join(", ") - ) - } - Self::Print { - line_info: _, - content, - } => { - let content_str = content - .iter() - .map(std::string::ToString::to_string) - .collect::>() - .join(", "); - format!("print({content_str})") - } - Self::MAlloc { - var, - size, - vectorized, - } => { - if *vectorized { - format!("{} = malloc_vectorized({})", var, size.to_string()) - } else { - format!("{} = malloc({})", var, size.to_string()) - } - } - Self::DecomposeBits { var, to_decompose } => { - format!("{} = decompose_bits({})", var, to_decompose.to_string()) - } - Self::Break => "break".to_string(), - Self::Panic => "panic".to_string(), - }; - format!("{spaces}{line_str}") - } -} - -impl ToString for Boolean { - fn to_string(&self) -> String { - match self { - Self::Equal { left, right } => { - format!("{} == {}", left.to_string(), right.to_string()) - } - Self::Different { left, right } => { - format!("{} != {}", left.to_string(), right.to_string()) - } - } - } -} - -impl ToString for ConstantValue { - fn to_string(&self) -> String { - match self { - Self::Scalar(scalar) => scalar.to_string(), - Self::PublicInputStart => "@public_input_start".to_string(), - Self::PointerToZeroVector => "@pointer_to_zero_vector".to_string(), - Self::FunctionSize { function_name } => { - format!("@function_size_{function_name}") - } - Self::Label(label) => label.to_string(), - } - } -} - -impl ToString for SimpleExpr { - fn to_string(&self) -> String { - match self { - Self::Var(var) => var.to_string(), - Self::Constant(constant) => constant.to_string(), - Self::ConstMallocAccess { - malloc_label, - offset, - } => { - format!("malloc_access({}, {})", malloc_label, offset.to_string()) - } - } - } -} - -impl ToString for ConstExpression { - fn to_string(&self) -> String { - match self { - Self::Value(value) => value.to_string(), - Self::Binary { - left, - operation, - right, - } => { - format!( - "({} {} {})", - left.to_string(), - operation.to_string(), - right.to_string() - ) - } - } - } -} - -impl ToString for Line { - fn to_string(&self) -> String { - self.to_string_with_indent(0) - } -} - -impl ToString for Program { - fn to_string(&self) -> String { - let mut result = String::new(); - for (i, function) in self.functions.values().enumerate() { - if i > 0 { - result.push('\n'); - } - result.push_str(&function.to_string()); - } - result - } -} - -impl ToString for Function { - fn to_string(&self) -> String { - let args_str = self - .arguments - .iter() - .map(|arg| match arg { - (name, true) => format!("const {name}"), - (name, false) => name.to_string(), - }) - .collect::>() - .join(", "); - - let instructions_str = self - .body - .iter() - .map(|line| line.to_string_with_indent(1)) - .collect::>() - .join("\n"); - - if self.body.is_empty() { - format!( - "fn {}({}) -> {} {{}}", - self.name, args_str, self.n_returned_vars - ) - } else { - format!( - "fn {}({}) -> {} {{\n{}\n}}", - self.name, args_str, self.n_returned_vars, instructions_str - ) - } - } -} diff --git a/crates/compiler/src/lang/boolean.rs b/crates/compiler/src/lang/boolean.rs new file mode 100644 index 00000000..03bdd7c5 --- /dev/null +++ b/crates/compiler/src/lang/boolean.rs @@ -0,0 +1,18 @@ +use std::fmt; + +use crate::lang::expression::Expression; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Boolean { + Equal { left: Expression, right: Expression }, + Different { left: Expression, right: Expression }, +} + +impl fmt::Display for Boolean { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Equal { left, right } => write!(f, "{left} == {right}"), + Self::Different { left, right } => write!(f, "{left} != {right}"), + } + } +} diff --git a/crates/compiler/src/lang/const_expr.rs b/crates/compiler/src/lang/const_expr.rs new file mode 100644 index 00000000..a894cb02 --- /dev/null +++ b/crates/compiler/src/lang/const_expr.rs @@ -0,0 +1,140 @@ +use std::fmt; + +use p3_field::{PrimeCharacteristicRing, PrimeField}; +use vm::F; + +use crate::{ + Compiler, + intermediate_bytecode::HighLevelOperation, + lang::{Label, constant_value::ConstantValue, expression::Expression, simple_expr::SimpleExpr}, +}; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ConstExpression { + Value(ConstantValue), + Binary { + left: Box, + operation: HighLevelOperation, + right: Box, + }, +} + +impl ConstExpression { + #[must_use] + pub const fn zero() -> Self { + Self::scalar(0) + } + + #[must_use] + pub const fn one() -> Self { + Self::scalar(1) + } + + #[must_use] + pub const fn label(label: Label) -> Self { + Self::Value(ConstantValue::Label(label)) + } + + #[must_use] + pub const fn scalar(scalar: usize) -> Self { + Self::Value(ConstantValue::Scalar(scalar)) + } + + #[must_use] + pub const fn function_size(function_name: Label) -> Self { + Self::Value(ConstantValue::FunctionSize { function_name }) + } + + pub fn eval_with(&self, func: &EvalFn) -> Option + where + EvalFn: Fn(&ConstantValue) -> Option, + { + match self { + Self::Value(value) => func(value), + Self::Binary { + left, + operation, + right, + } => Some(operation.eval(left.eval_with(func)?, right.eval_with(func)?)), + } + } + + #[must_use] + pub fn naive_eval(&self) -> Option { + self.eval_with(&|value| match value { + ConstantValue::Scalar(scalar) => Some(F::from_usize(*scalar)), + _ => None, + }) + } + + #[must_use] + pub fn try_naive_simplification(&self) -> Self { + self.naive_eval().map_or_else( + || self.clone(), + |value| Self::scalar(value.as_canonical_biguint().try_into().unwrap()), + ) + } + + #[must_use] + pub fn eval(&self, compiler: &Compiler) -> F { + self.eval_with(&|cst| Some(F::from_usize(cst.eval(compiler)))) + .unwrap() + } + + #[must_use] + pub fn eval_usize(&self, compiler: &Compiler) -> usize { + self.eval(compiler) + .as_canonical_biguint() + .try_into() + .unwrap() + } +} + +impl From for ConstExpression { + fn from(value: usize) -> Self { + Self::Value(ConstantValue::Scalar(value)) + } +} + +impl TryFrom for ConstExpression { + type Error = (); + + fn try_from(value: Expression) -> Result { + match value { + Expression::Value(SimpleExpr::Constant(const_expr)) => Ok(const_expr), + Expression::Value(_) | Expression::ArrayAccess { .. } => Err(()), + Expression::Binary { + left, + operation, + right, + } => { + let left_expr = Self::try_from(*left)?; + let right_expr = Self::try_from(*right)?; + Ok(Self::Binary { + left: Box::new(left_expr), + operation, + right: Box::new(right_expr), + }) + } + } + } +} + +impl From for ConstExpression { + fn from(value: ConstantValue) -> Self { + Self::Value(value) + } +} + +impl fmt::Display for ConstExpression { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Value(value) => write!(f, "{value}"), + Self::Binary { + left, + operation, + right, + } => write!(f, "({left} {operation} {right})"), + } + } +} diff --git a/crates/compiler/src/lang/constant_value.rs b/crates/compiler/src/lang/constant_value.rs new file mode 100644 index 00000000..e9a873e5 --- /dev/null +++ b/crates/compiler/src/lang/constant_value.rs @@ -0,0 +1,44 @@ +use std::fmt; + +use vm::{PUBLIC_INPUT_START, ZERO_VEC_PTR}; + +use crate::{Compiler, lang::Label}; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ConstantValue { + Scalar(usize), + PublicInputStart, + PointerToZeroVector, // In the memory of chunks of 8 field elements + FunctionSize { function_name: Label }, + Label(Label), +} + +impl ConstantValue { + #[must_use] + pub fn eval(&self, compiler: &Compiler) -> usize { + match self { + Self::Scalar(scalar) => *scalar, + Self::PublicInputStart => PUBLIC_INPUT_START, + Self::PointerToZeroVector => ZERO_VEC_PTR, + Self::FunctionSize { function_name } => *compiler + .memory_size_per_function + .get(function_name) + .unwrap_or_else(|| panic!("Function {function_name} not found in memory size map")), + Self::Label(label) => compiler.label_to_pc.get(label).copied().unwrap(), + } + } +} + +impl fmt::Display for ConstantValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Scalar(scalar) => write!(f, "{scalar}"), + Self::PublicInputStart => write!(f, "@public_input_start"), + Self::PointerToZeroVector => write!(f, "@pointer_to_zero_vector"), + Self::FunctionSize { function_name } => { + write!(f, "@function_size_{function_name}") + } + Self::Label(label) => write!(f, "{label}"), + } + } +} diff --git a/crates/compiler/src/lang/expression.rs b/crates/compiler/src/lang/expression.rs new file mode 100644 index 00000000..816210c5 --- /dev/null +++ b/crates/compiler/src/lang/expression.rs @@ -0,0 +1,79 @@ +use std::fmt; + +use vm::F; + +use crate::{ + intermediate_bytecode::HighLevelOperation, + lang::{Var, simple_expr::SimpleExpr}, +}; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Expression { + Value(SimpleExpr), + ArrayAccess { + array: Var, + index: Box, + }, + Binary { + left: Box, + operation: HighLevelOperation, + right: Box, + }, +} + +impl Expression { + #[must_use] + pub fn naive_eval(&self) -> Option { + self.eval_with( + &|value: &SimpleExpr| value.as_constant()?.naive_eval(), + &|_, _| None, + ) + } + + pub fn eval_with(&self, value_fn: &ValueFn, array_fn: &ArrayFn) -> Option + where + ValueFn: Fn(&SimpleExpr) -> Option, + ArrayFn: Fn(&Var, F) -> Option, + { + match self { + Self::Value(value) => value_fn(value), + Self::ArrayAccess { array, index } => { + array_fn(array, index.eval_with(value_fn, array_fn)?) + } + Self::Binary { + left, + operation, + right, + } => Some(operation.eval( + left.eval_with(value_fn, array_fn)?, + right.eval_with(value_fn, array_fn)?, + )), + } + } +} + +impl From for Expression { + fn from(value: SimpleExpr) -> Self { + Self::Value(value) + } +} + +impl From for Expression { + fn from(var: Var) -> Self { + Self::Value(var.into()) + } +} + +impl fmt::Display for Expression { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Value(val) => write!(f, "{val}"), + Self::ArrayAccess { array, index } => write!(f, "{array}[{index}]"), + Self::Binary { + left, + operation, + right, + } => write!(f, "({left} {operation} {right})"), + } + } +} diff --git a/crates/compiler/src/lang/function.rs b/crates/compiler/src/lang/function.rs new file mode 100644 index 00000000..4e96bc49 --- /dev/null +++ b/crates/compiler/src/lang/function.rs @@ -0,0 +1,53 @@ +use std::fmt; + +use crate::lang::{Var, line::Line}; + +#[derive(Debug, Clone)] +pub struct Function { + pub name: String, + pub arguments: Vec<(Var, bool)>, // (name, is_const) + pub n_returned_vars: usize, + pub body: Vec, +} + +impl Function { + #[must_use] + pub fn has_const_arguments(&self) -> bool { + self.arguments.iter().any(|(_, is_const)| *is_const) + } +} + +impl fmt::Display for Function { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let args_str = self + .arguments + .iter() + .map(|arg| match arg { + (name, true) => format!("const {name}"), + (name, false) => name.clone(), + }) + .collect::>() + .join(", "); + + let instructions_str = self + .body + .iter() + .map(|line| line.to_string_with_indent(1)) + .collect::>() + .join("\n"); + + if self.body.is_empty() { + write!( + f, + "fn {}({}) -> {} {{}}", + self.name, args_str, self.n_returned_vars + ) + } else { + write!( + f, + "fn {}({}) -> {} {{\n{}\n}}", + self.name, args_str, self.n_returned_vars, instructions_str + ) + } + } +} diff --git a/crates/compiler/src/lang/line.rs b/crates/compiler/src/lang/line.rs new file mode 100644 index 00000000..d75fea0e --- /dev/null +++ b/crates/compiler/src/lang/line.rs @@ -0,0 +1,214 @@ +use std::fmt; + +use crate::{ + lang::{Var, boolean::Boolean, expression::Expression}, + precompiles::Precompile, +}; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Line { + Assignment { + var: Var, + value: Expression, + }, + ArrayAssign { + // array[index] = value + array: Var, + index: Expression, + value: Expression, + }, + Assert(Boolean), + IfCondition { + condition: Boolean, + then_branch: Vec, + else_branch: Vec, + }, + ForLoop { + iterator: Var, + start: Expression, + end: Expression, + body: Vec, + rev: bool, + unroll: bool, + }, + FunctionCall { + function_name: String, + args: Vec, + return_data: Vec, + }, + FunctionRet { + return_data: Vec, + }, + Precompile { + precompile: Precompile, + args: Vec, + res: Vec, + }, + Break, + Panic, + // Hints: + Print { + line_info: String, + content: Vec, + }, + MAlloc { + var: Var, + size: Expression, + vectorized: bool, + }, + DecomposeBits { + var: Var, // a pointer to 31 field elements, containing the bits of "to_decompose" + to_decompose: Expression, + }, +} + +impl Line { + pub(crate) fn to_string_with_indent(&self, indent: usize) -> String { + let spaces = " ".repeat(indent); + let line_str = match self { + Self::Assignment { var, value } => { + format!("{var} = {value}") + } + Self::ArrayAssign { + array, + index, + value, + } => { + format!("{array}[{index}] = {value}") + } + Self::Assert(condition) => format!("assert {condition}"), + Self::IfCondition { + condition, + then_branch, + else_branch, + } => { + let then_str = then_branch + .iter() + .map(|line| line.to_string_with_indent(indent + 1)) + .collect::>() + .join("\n"); + + let else_str = else_branch + .iter() + .map(|line| line.to_string_with_indent(indent + 1)) + .collect::>() + .join("\n"); + + if else_branch.is_empty() { + format!("if {condition} {{\n{then_str}\n{spaces}}}") + } else { + format!( + "if {condition} {{\n{then_str}\n{spaces}}} else {{\n{else_str}\n{spaces}}}" + ) + } + } + Self::ForLoop { + iterator, + start, + end, + body, + rev, + unroll, + } => { + let body_str = body + .iter() + .map(|line| line.to_string_with_indent(indent + 1)) + .collect::>() + .join("\n"); + format!( + "for {} in {}{}..{} {}{{\n{}\n{}}}", + iterator, + start, + if *rev { "rev " } else { "" }, + end, + if *unroll { "unroll " } else { "" }, + body_str, + spaces + ) + } + Self::FunctionCall { + function_name, + args, + return_data, + } => { + let args_str = args + .iter() + .map(|arg| arg.to_string()) + .collect::>() + .join(", "); + let return_data_str = return_data + .iter() + .map(|var| var.to_string()) + .collect::>() + .join(", "); + + if return_data.is_empty() { + format!("{function_name}({args_str})") + } else { + format!("{return_data_str} = {function_name}({args_str})") + } + } + Self::FunctionRet { return_data } => { + let return_data_str = return_data + .iter() + .map(|arg| arg.to_string()) + .collect::>() + .join(", "); + format!("return {return_data_str}") + } + Self::Precompile { + precompile, + args, + res: return_data, + } => { + format!( + "{} = {}({})", + return_data + .iter() + .map(|var| var.to_string()) + .collect::>() + .join(", "), + precompile.name, + args.iter() + .map(|arg| arg.to_string()) + .collect::>() + .join(", ") + ) + } + Self::Print { + line_info: _, + content, + } => { + let content_str = content + .iter() + .map(|c| c.to_string()) + .collect::>() + .join(", "); + format!("print({content_str})") + } + Self::MAlloc { + var, + size, + vectorized, + } => { + if *vectorized { + format!("{var} = malloc_vectorized({size})") + } else { + format!("{var} = malloc({size})") + } + } + Self::DecomposeBits { var, to_decompose } => { + format!("{var} = decompose_bits({to_decompose})") + } + Self::Break => "break".to_string(), + Self::Panic => "panic".to_string(), + }; + format!("{spaces}{line_str}") + } +} + +impl fmt::Display for Line { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.to_string_with_indent(0)) + } +} diff --git a/crates/compiler/src/lang/mod.rs b/crates/compiler/src/lang/mod.rs new file mode 100644 index 00000000..ad1176e9 --- /dev/null +++ b/crates/compiler/src/lang/mod.rs @@ -0,0 +1,20 @@ +pub mod boolean; +pub use boolean::*; +pub mod const_expr; +pub use const_expr::*; +pub mod constant_value; +pub use constant_value::*; +pub mod expression; +pub use expression::*; +pub mod function; +pub use function::*; +pub mod line; +pub use line::*; +pub mod program; +pub use program::*; +pub mod simple_expr; +pub use simple_expr::*; + +pub type Var = String; +pub type ConstMallocLabel = usize; +pub type Label = String; diff --git a/crates/compiler/src/lang/program.rs b/crates/compiler/src/lang/program.rs new file mode 100644 index 00000000..b0a7f8f8 --- /dev/null +++ b/crates/compiler/src/lang/program.rs @@ -0,0 +1,22 @@ +use std::{collections::BTreeMap, fmt}; + +use crate::lang::function::Function; + +#[derive(Debug, Clone)] +pub struct Program { + pub functions: BTreeMap, +} + +impl fmt::Display for Program { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let program_str = self + .functions + .values() + .map(ToString::to_string) + .collect::>() + .join("\n\n"); + + // Write the final, joined string to the formatter. + write!(f, "{program_str}") + } +} diff --git a/crates/compiler/src/lang/simple_expr.rs b/crates/compiler/src/lang/simple_expr.rs new file mode 100644 index 00000000..d9f20854 --- /dev/null +++ b/crates/compiler/src/lang/simple_expr.rs @@ -0,0 +1,84 @@ +use std::fmt; + +use crate::lang::{ + ConstMallocLabel, Var, const_expr::ConstExpression, constant_value::ConstantValue, +}; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum SimpleExpr { + Var(Var), + Constant(ConstExpression), + ConstMallocAccess { + malloc_label: ConstMallocLabel, + offset: ConstExpression, + }, +} + +impl SimpleExpr { + #[must_use] + pub fn zero() -> Self { + Self::scalar(0) + } + + #[must_use] + pub fn one() -> Self { + Self::scalar(1) + } + + #[must_use] + pub fn scalar(scalar: usize) -> Self { + Self::Constant(ConstantValue::Scalar(scalar).into()) + } + + #[must_use] + pub const fn is_constant(&self) -> bool { + matches!(self, Self::Constant(_)) + } + + #[must_use] + pub fn simplify_if_const(&self) -> Self { + if let Self::Constant(constant) = self { + return constant.try_naive_simplification().into(); + } + self.clone() + } + + #[must_use] + pub fn as_constant(&self) -> Option { + match self { + Self::Var(_) | Self::ConstMallocAccess { .. } => None, + Self::Constant(constant) => Some(constant.clone()), + } + } +} + +impl From for SimpleExpr { + fn from(constant: ConstantValue) -> Self { + Self::Constant(constant.into()) + } +} + +impl From for SimpleExpr { + fn from(constant: ConstExpression) -> Self { + Self::Constant(constant) + } +} + +impl From for SimpleExpr { + fn from(var: Var) -> Self { + Self::Var(var) + } +} + +impl fmt::Display for SimpleExpr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Var(var) => write!(f, "{var}"), + Self::Constant(constant) => write!(f, "{constant}"), + Self::ConstMallocAccess { + malloc_label, + offset, + } => write!(f, "malloc_access({malloc_label}, {offset})"), + } + } +} diff --git a/crates/compiler/src/lib.rs b/crates/compiler/src/lib.rs index c27ff07b..97ebca3c 100644 --- a/crates/compiler/src/lib.rs +++ b/crates/compiler/src/lib.rs @@ -1,4 +1,4 @@ -use vm::{Bytecode, F, PUBLIC_INPUT_START, ZERO_VEC_PTR, execute_bytecode}; +use vm::{Bytecode, F, Label, PUBLIC_INPUT_START, ZERO_VEC_PTR, execute_bytecode}; use crate::{ a_simplify_lang::simplify_program, b_compile_intermediate::compile_to_intermediate_bytecode, @@ -12,8 +12,16 @@ mod intermediate_bytecode; mod lang; mod parser; mod precompiles; +use std::collections::BTreeMap; + pub use precompiles::PRECOMPILES; +#[derive(Debug)] +pub struct Compiler { + pub memory_size_per_function: BTreeMap, + pub label_to_pc: BTreeMap, +} + #[must_use] pub fn compile_program(program: &str) -> Bytecode { let parsed_program = parse_program(program).unwrap(); diff --git a/crates/compiler/src/parser.rs b/crates/compiler/src/parser.rs index 20cb1c6d..b58630ac 100644 --- a/crates/compiler/src/parser.rs +++ b/crates/compiler/src/parser.rs @@ -638,7 +638,7 @@ fn my_function1(a, const b, c) -> 2 { "; let parsed = parse_program(program).unwrap(); - println!("{}", parsed.to_string()); + println!("{parsed}"); } #[test] @@ -651,7 +651,7 @@ fn test_func(const a, b, const c) -> 1 { "; let parsed = parse_program(program).unwrap(); - println!("{}", parsed.to_string()); + println!("{parsed}"); } #[test] @@ -667,6 +667,6 @@ fn test_exp() -> 1 { "; let parsed = parse_program(program).unwrap(); - println!("{}", parsed.to_string()); + println!("{parsed}"); } } diff --git a/crates/compiler/src/precompiles.rs b/crates/compiler/src/precompiles.rs index 9e53d05d..e38b1607 100644 --- a/crates/compiler/src/precompiles.rs +++ b/crates/compiler/src/precompiles.rs @@ -1,3 +1,5 @@ +use std::fmt; + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Precompile { pub name: PrecompileName, @@ -13,15 +15,22 @@ pub enum PrecompileName { MultilinearEval, } -impl ToString for PrecompileName { - fn to_string(&self) -> String { +impl PrecompileName { + /// Returns the string representation of the precompile name. + #[must_use] + pub const fn as_str(&self) -> &'static str { match self { Self::Poseidon16 => "poseidon16", Self::Poseidon24 => "poseidon24", Self::DotProduct => "dot_product", Self::MultilinearEval => "multilinear_eval", } - .to_string() + } +} + +impl fmt::Display for PrecompileName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) } }