diff --git a/prusti-encoder/src/encoders/const.rs b/prusti-encoder/src/encoders/const.rs index aaf8b1a4a7b..4f6dac76d47 100644 --- a/prusti-encoder/src/encoders/const.rs +++ b/prusti-encoder/src/encoders/const.rs @@ -11,6 +11,19 @@ use vir::{CallableIdent, Arity}; pub struct ConstEnc; + +#[derive(Clone)] +pub struct ConstEncOutput<'vir>(pub vir::Expr<'vir>); + + +impl<'vir> task_encoder::Optimizable for ConstEncOutput<'vir> {} + +impl<'vir> From> for ConstEncOutput<'vir> { + fn from(value: vir::Expr<'vir>) -> Self { + Self(value) + } +} + #[derive(Clone, Debug)] pub struct ConstEncOutputRef<'vir> { pub base_name: String, @@ -28,7 +41,7 @@ impl TaskEncoder for ConstEnc { usize, // current encoding depth DefId, // DefId of the current function ); - type OutputFullLocal<'vir> = vir::Expr<'vir>; + type OutputFullLocal<'vir> = ConstEncOutput<'vir>; type EncodingError = (); fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> Self::TaskKey<'vir> { @@ -94,6 +107,6 @@ impl TaskEncoder for ConstEnc { }), mir::ConstantKind::Ty(_) => todo!(), }; - Ok((res, ())) + Ok((res.into(), ())) } } diff --git a/prusti-encoder/src/encoders/generic.rs b/prusti-encoder/src/encoders/generic.rs index ec380e7538a..f46e127dc47 100644 --- a/prusti-encoder/src/encoders/generic.rs +++ b/prusti-encoder/src/encoders/generic.rs @@ -26,6 +26,9 @@ pub struct GenericEncOutput<'vir> { pub domain_type: vir::Domain<'vir>, } +impl<'vir> task_encoder::Optimizable for GenericEncOutput<'vir> {} + + impl TaskEncoder for GenericEnc { task_encoder::encoder_cache!(GenericEnc); diff --git a/prusti-encoder/src/encoders/local_def.rs b/prusti-encoder/src/encoders/local_def.rs index 4f75176976b..313af9ebae2 100644 --- a/prusti-encoder/src/encoders/local_def.rs +++ b/prusti-encoder/src/encoders/local_def.rs @@ -16,6 +16,9 @@ pub struct MirLocalDefEncOutput<'vir> { } pub type MirLocalDefEncError = (); + +impl<'vir> task_encoder::Optimizable for MirLocalDefEncOutput<'vir> {} + #[derive(Clone, Copy)] pub struct LocalDef<'vir> { pub local: vir::Local<'vir>, diff --git a/prusti-encoder/src/encoders/mir_builtin.rs b/prusti-encoder/src/encoders/mir_builtin.rs index cedb865a721..90128a2581d 100644 --- a/prusti-encoder/src/encoders/mir_builtin.rs +++ b/prusti-encoder/src/encoders/mir_builtin.rs @@ -33,6 +33,9 @@ pub struct MirBuiltinEncOutput<'vir> { pub function: vir::Function<'vir>, } +impl<'vir> task_encoder::Optimizable for MirBuiltinEncOutput<'vir> {} + + use crate::encoders::SnapshotEnc; impl TaskEncoder for MirBuiltinEnc { diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index b6fea2d5d6f..6a4482a6106 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -12,7 +12,7 @@ use task_encoder::{ TaskEncoder, TaskEncoderDependencies, }; -use vir::{MethodIdent, UnknownArity, CallableIdent}; +use vir::{with_vcx, CallableIdent, MethodIdent, Optimizable, UnknownArity}; pub struct MirImpureEnc; @@ -32,6 +32,15 @@ pub struct MirImpureEncOutput<'vir> { pub method: vir::Method<'vir>, } +impl<'vir> task_encoder::Optimizable for MirImpureEncOutput<'vir> { + fn optimize(self) -> Self { + let method = self.method.optimize(); + let method = with_vcx(|vcx| vcx.alloc(method)); + MirImpureEncOutput { method } + } +} + + use crate::encoders::{PredicateEnc, ConstEnc, MirBuiltinEnc, MirFunctionEnc, MirLocalDefEnc, MirSpecEnc}; const ENCODE_REACH_BB: bool = false; @@ -391,7 +400,7 @@ impl<'tcx, 'vir, 'enc> EncVisitor<'tcx, 'vir, 'enc> { ty_out.ref_to_snap.apply(self.vcx, [self.encode_place(Place::from(source))]) } mir::Operand::Constant(box constant) => - self.deps.require_local::((constant.literal, 0, self.def_id)).unwrap() + self.deps.require_local::((constant.literal, 0, self.def_id)).unwrap().0 } } @@ -409,7 +418,7 @@ impl<'tcx, 'vir, 'enc> EncVisitor<'tcx, 'vir, 'enc> { } mir::Operand::Constant(box constant) => { let ty_out = self.deps.require_ref::(ty).unwrap(); - let constant = self.deps.require_local::((constant.literal, 0, self.def_id)).unwrap(); + let constant = self.deps.require_local::((constant.literal, 0, self.def_id)).unwrap().0; (constant, ty_out) } }; diff --git a/prusti-encoder/src/encoders/mir_pure.rs b/prusti-encoder/src/encoders/mir_pure.rs index 9863c9c9b0e..c42a46c7f08 100644 --- a/prusti-encoder/src/encoders/mir_pure.rs +++ b/prusti-encoder/src/encoders/mir_pure.rs @@ -31,6 +31,10 @@ pub struct MirPureEncOutput<'vir> { pub expr: ExprRet<'vir>, } +impl<'vir> task_encoder::Optimizable for MirPureEncOutput<'vir> {} + + + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum PureKind { Closure, @@ -624,7 +628,7 @@ impl<'tcx, 'vir: 'enc, 'enc> Enc<'tcx, 'vir, 'enc> mir::Operand::Copy(place) | mir::Operand::Move(place) => self.encode_place(curr_ver, place), mir::Operand::Constant(box constant) => - self.deps.require_local::((constant.literal, self.encoding_depth, self.def_id)).unwrap().lift(), + self.deps.require_local::((constant.literal, self.encoding_depth, self.def_id)).unwrap().0.lift(), } } diff --git a/prusti-encoder/src/encoders/mir_pure_function.rs b/prusti-encoder/src/encoders/mir_pure_function.rs index 2606fce0cea..df333098a5d 100644 --- a/prusti-encoder/src/encoders/mir_pure_function.rs +++ b/prusti-encoder/src/encoders/mir_pure_function.rs @@ -1,7 +1,7 @@ use prusti_rustc_interface::{middle::{mir, ty}, span::def_id::DefId}; use task_encoder::{TaskEncoder, TaskEncoderDependencies}; -use vir::{Reify, FunctionIdent, UnknownArity, CallableIdent}; +use vir::{CallableIdent, FunctionIdent, Optimizable, Reify, UnknownArity}; use crate::encoders::{ MirPureEnc, MirPureEncTask, mir_pure::PureKind, MirSpecEnc, MirLocalDefEnc, @@ -28,6 +28,15 @@ pub struct MirFunctionEncOutput<'vir> { pub function: vir::Function<'vir>, } +impl<'vir> task_encoder::Optimizable for MirFunctionEncOutput<'vir> { + fn optimize(self) -> Self { + let function = self.function.optimize(); + let function = vir::with_vcx(|vcx| vcx.alloc(function)); + + MirFunctionEncOutput { function } + } +} + impl TaskEncoder for MirFunctionEnc { task_encoder::encoder_cache!(MirFunctionEnc); diff --git a/prusti-encoder/src/encoders/pure/spec.rs b/prusti-encoder/src/encoders/pure/spec.rs index 0f67d088446..130b673ada7 100644 --- a/prusti-encoder/src/encoders/pure/spec.rs +++ b/prusti-encoder/src/encoders/pure/spec.rs @@ -14,6 +14,9 @@ pub struct MirSpecEncOutput<'vir> { pub post_args: &'vir [vir::Expr<'vir>], } +impl<'vir> task_encoder::Optimizable for MirSpecEncOutput<'vir> {} + + impl TaskEncoder for MirSpecEnc { task_encoder::encoder_cache!(MirSpecEnc); diff --git a/prusti-encoder/src/encoders/spec.rs b/prusti-encoder/src/encoders/spec.rs index e0418dc38c0..dbc52c235e0 100644 --- a/prusti-encoder/src/encoders/spec.rs +++ b/prusti-encoder/src/encoders/spec.rs @@ -4,8 +4,7 @@ use prusti_rustc_interface::{ }; use prusti_interface::specs::typed::{DefSpecificationMap, ProcedureSpecification}; use task_encoder::{ - TaskEncoder, - TaskEncoderDependencies, + Optimizable, TaskEncoder, TaskEncoderDependencies }; pub struct SpecEnc; @@ -19,6 +18,8 @@ pub struct SpecEncOutput<'vir> { pub posts: &'vir [DefId], } +impl<'vir> Optimizable for SpecEncOutput<'vir> {} + use std::cell::RefCell; thread_local! { static DEF_SPEC_MAP: RefCell> = RefCell::new(Default::default()); diff --git a/prusti-encoder/src/encoders/type/domain.rs b/prusti-encoder/src/encoders/type/domain.rs index a3b71398c7f..bb3932b42a8 100644 --- a/prusti-encoder/src/encoders/type/domain.rs +++ b/prusti-encoder/src/encoders/type/domain.rs @@ -77,9 +77,21 @@ pub struct DomainEncOutputRef<'vir> { } impl<'vir> task_encoder::OutputRefAny for DomainEncOutputRef<'vir> {} + +#[derive(Clone)] +pub struct DomainEncOutput<'vir>(pub vir::Domain<'vir>); + +impl<'vir> task_encoder::Optimizable for DomainEncOutput<'vir> {} + +impl<'vir> From> for DomainEncOutput<'vir> { + fn from(value: vir::Domain<'vir>) -> Self { + DomainEncOutput(value) + } +} + use crate::encoders::SnapshotEnc; -pub fn all_outputs<'vir>() -> Vec> { +pub fn all_outputs<'vir>() -> Vec> { DomainEnc::all_outputs() } @@ -90,7 +102,7 @@ impl TaskEncoder for DomainEnc { type OutputRef<'vir> = DomainEncOutputRef<'vir>; type OutputFullDependency<'vir> = DomainEncSpecifics<'vir>; - type OutputFullLocal<'vir> = vir::Domain<'vir>; + type OutputFullLocal<'vir> = DomainEncOutput<'vir>; //type OutputFullDependency<'vir> = DomainEncOutputDep<'vir>; type EncodingError = (); @@ -109,7 +121,7 @@ impl TaskEncoder for DomainEnc { Self::EncodingError, Option>, )> { - vir::with_vcx(|vcx| match task_key.kind() { + (vir::with_vcx(|vcx| match task_key.kind() { TyKind::Bool | TyKind::Char | TyKind::Int(_) | TyKind::Uint(_) | TyKind::Float(_) => { let (base_name, prim_type) = match task_key.kind() { TyKind::Bool => (String::from("Bool"), &vir::TypeData::Bool), @@ -197,7 +209,7 @@ impl TaskEncoder for DomainEnc { Ok((enc.finalize(), specifics)) } kind => todo!("{kind:?}"), - }) + })) } } @@ -531,13 +543,13 @@ impl<'vir, 'tcx> DomainEncData<'vir, 'tcx> { domain: self.domain, } } - fn finalize(self) -> vir::Domain<'vir> { + fn finalize(self) -> DomainEncOutput<'vir> { self.vcx.mk_domain( self.domain.name(), self.domain.arity().args(), self.vcx.alloc_slice(&self.axioms), self.vcx.alloc_slice(&self.functions), - ) + ).into() } } diff --git a/prusti-encoder/src/encoders/type/predicate.rs b/prusti-encoder/src/encoders/type/predicate.rs index b4ca2afeda3..b7063de5458 100644 --- a/prusti-encoder/src/encoders/type/predicate.rs +++ b/prusti-encoder/src/encoders/type/predicate.rs @@ -146,6 +146,9 @@ pub struct PredicateEncOutput<'vir> { pub method_assign: vir::Method<'vir>, } +impl<'vir> task_encoder::Optimizable for PredicateEncOutput<'vir> {} + + use super::{snapshot::SnapshotEnc, domain::{DomainDataPrim, DomainDataStruct, DomainDataEnum, DiscrBounds}}; impl TaskEncoder for PredicateEnc { diff --git a/prusti-encoder/src/encoders/type/snapshot.rs b/prusti-encoder/src/encoders/type/snapshot.rs index 5fc2a035abf..96b494efc44 100644 --- a/prusti-encoder/src/encoders/type/snapshot.rs +++ b/prusti-encoder/src/encoders/type/snapshot.rs @@ -25,6 +25,9 @@ pub struct SnapshotEncOutput<'vir> { pub specifics: DomainEncSpecifics<'vir>, } + +impl<'vir> task_encoder::Optimizable for SnapshotEncOutput<'vir> {} + use super::domain::{DomainEnc, DomainEncSpecifics}; impl TaskEncoder for SnapshotEnc { diff --git a/prusti-encoder/src/encoders/type/viper_tuple.rs b/prusti-encoder/src/encoders/type/viper_tuple.rs index e004a6811e6..d196937277f 100644 --- a/prusti-encoder/src/encoders/type/viper_tuple.rs +++ b/prusti-encoder/src/encoders/type/viper_tuple.rs @@ -15,6 +15,9 @@ pub struct ViperTupleEncOutput<'vir> { tuple: Option>, } +impl<'vir> task_encoder::Optimizable for ViperTupleEncOutput<'vir> {} + + impl<'vir> ViperTupleEncOutput<'vir> { pub fn mk_cons<'tcx, Curr, Next>( &self, diff --git a/prusti-encoder/src/lib.rs b/prusti-encoder/src/lib.rs index b9ebdc71804..c8d9264b758 100644 --- a/prusti-encoder/src/lib.rs +++ b/prusti-encoder/src/lib.rs @@ -14,6 +14,26 @@ use prusti_rustc_interface::{ hir, }; + +const ENABLE_OPTIMIZATION : bool = true; + +// Wrapper Trait for task_encoder::Optimizable to allow toggling of optimization +// TODO: replace with config +trait MaybeOptimize { + fn optimize(self) -> Self; +} + +impl MaybeOptimize for T where T : task_encoder::Optimizable { + fn optimize(self) -> Self { + if ENABLE_OPTIMIZATION { + task_encoder::Optimizable::optimize(self) + } + else { + self + } + } +} + pub fn test_entrypoint<'tcx>( tcx: ty::TyCtxt<'tcx>, body: EnvBody<'tcx>, @@ -63,34 +83,34 @@ pub fn test_entrypoint<'tcx>( let mut viper_code = String::new(); header(&mut viper_code, "methods"); - for output in crate::encoders::MirImpureEnc::all_outputs() { + for output in crate::encoders::MirImpureEnc::all_outputs().optimize() { viper_code.push_str(&format!("{:?}\n", output.method)); } header(&mut viper_code, "functions"); - for output in crate::encoders::MirFunctionEnc::all_outputs() { + for output in crate::encoders::MirFunctionEnc::all_outputs().optimize() { viper_code.push_str(&format!("{:?}\n", output.function)); } header(&mut viper_code, "MIR builtins"); - for output in crate::encoders::MirBuiltinEnc::all_outputs() { + for output in crate::encoders::MirBuiltinEnc::all_outputs().optimize() { viper_code.push_str(&format!("{:?}\n", output.function)); } header(&mut viper_code, "generics"); - for output in crate::encoders::GenericEnc::all_outputs() { + for output in crate::encoders::GenericEnc::all_outputs().optimize() { viper_code.push_str(&format!("{:?}\n", output.snapshot_param)); viper_code.push_str(&format!("{:?}\n", output.predicate_param)); viper_code.push_str(&format!("{:?}\n", output.domain_type)); } header(&mut viper_code, "snapshots"); - for output in crate::encoders::DomainEnc_all_outputs() { - viper_code.push_str(&format!("{:?}\n", output)); + for output in crate::encoders::DomainEnc_all_outputs().optimize() { + viper_code.push_str(&format!("{:?}\n", output.0)); } header(&mut viper_code, "types"); - for output in crate::encoders::PredicateEnc::all_outputs() { + for output in crate::encoders::PredicateEnc::all_outputs().optimize() { for field in output.fields { viper_code.push_str(&format!("{:?}", field)); } diff --git a/task-encoder/src/lib.rs b/task-encoder/src/lib.rs index c2fab0eba50..d9693935946 100644 --- a/task-encoder/src/lib.rs +++ b/task-encoder/src/lib.rs @@ -6,6 +6,18 @@ use std::cell::RefCell; pub trait OutputRefAny {} impl OutputRefAny for () {} + +pub trait Optimizable: Sized { + fn optimize(self) -> Self { + self + } +} + +impl Optimizable for Vec where T: Optimizable { + fn optimize(self) -> Self { + self.into_iter().map(|e|e.optimize()).collect() + } +} pub enum TaskEncoderCacheState<'vir, E: TaskEncoder + 'vir + ?Sized> { // None, // indicated by absence in the cache @@ -177,7 +189,7 @@ pub trait TaskEncoder { /// Fully encoded output for this task. When encoding items which can be /// dependencies (such as methods), this output should only be emitted in /// one Viper program. - type OutputFullLocal<'vir>: Clone; + type OutputFullLocal<'vir>: Clone + Optimizable; /// Fully encoded output for this task for dependents. When encoding items /// which can be dependencies (such as methods), this output should be diff --git a/vir/src/data.rs b/vir/src/data.rs index efa01e78919..6f22d5352a4 100644 --- a/vir/src/data.rs +++ b/vir/src/data.rs @@ -88,6 +88,7 @@ pub enum ConstData { Null, } +#[derive(PartialEq, Eq)] pub enum TypeData<'vir> { Int { bit_width: u8, @@ -102,12 +103,12 @@ pub enum TypeData<'vir> { Unsupported(UnsupportedType<'vir>) } -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq)] pub struct UnsupportedType<'vir> { pub name: &'vir str, } -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct DomainParamData<'vir> { pub name: &'vir str, // TODO: identifiers } diff --git a/vir/src/folder.rs b/vir/src/folder.rs new file mode 100644 index 00000000000..c8fb395ecd2 --- /dev/null +++ b/vir/src/folder.rs @@ -0,0 +1,248 @@ +use crate::{ + AccFieldGenData, BinOpGenData, BinOpKind, ExprGen, ExprGenData, ExprKindGenData, Field, + ForallGenData, FuncAppGenData, LetGenData, LocalDecl, PredicateAppGen, PredicateAppGenData, + TernaryGenData, Type, UnOpGenData, UnOpKind, UnfoldingGenData, VirCtxt, +}; + +fn default_fold_expr<'vir, Cur, Next, T: ExprFolder<'vir, Cur, Next>>( + this: &mut T, + e: &'vir crate::ExprGenData<'vir, Cur, Next>, +) -> &'vir crate::ExprGenData<'vir, Cur, Next> { + match e.kind { + ExprKindGenData::Local(local) => this.fold_local(local), + ExprKindGenData::Field(recv, field) => this.fold_field(recv, field), + ExprKindGenData::Old(expr) => this.fold_old(expr), + ExprKindGenData::Const(value) => this.fold_const(value), + ExprKindGenData::Result => this.fold_result(), + ExprKindGenData::AccField(AccFieldGenData { recv, field, perm }) => { + this.fold_acc_field(recv, field, *perm) + } + ExprKindGenData::Unfolding(UnfoldingGenData { target, expr }) => { + this.fold_unfolding(target, expr) + } + ExprKindGenData::UnOp(UnOpGenData { kind, expr }) => this.fold_unop(*kind, expr), + ExprKindGenData::BinOp(BinOpGenData { kind, lhs, rhs }) => this.fold_binop(*kind, lhs, rhs), + ExprKindGenData::Ternary(TernaryGenData { cond, then, else_ }) => { + this.fold_ternary(cond, then, else_) + } + ExprKindGenData::Forall(ForallGenData { + qvars, + triggers, + body, + }) => this.fold_forall(qvars, triggers, body), + ExprKindGenData::Let(LetGenData { name, val, expr }) => this.fold_let(name, val, expr), + ExprKindGenData::FuncApp(FuncAppGenData { + target, + args, + result_ty, + }) => this.fold_func_app(target, args, *result_ty), + ExprKindGenData::PredicateApp(PredicateAppGenData { target, args, perm }) => { + this.fold_predicate_app(target, args, *perm) + } + ExprKindGenData::Lazy(name, func) => this.fold_lazy(name, func), + ExprKindGenData::Todo(msg) => this.fold_todo(msg), + } +} + +pub trait ExprFolder<'vir, Cur, Next>: Sized { + fn fold(&mut self, e: crate::ExprGen<'vir, Cur, Next>) -> crate::ExprGen<'vir, Cur, Next> { + default_fold_expr(self, e) + } + + fn fold_option( + &mut self, + e: Option>, + ) -> Option> { + e.map(|i| self.fold(i)) + } + + fn fold_slice( + &mut self, + s: &'vir [ExprGen<'vir, Cur, Next>], + ) -> &'vir [ExprGen<'vir, Cur, Next>] { + let vec = s.iter().map(|e| self.fold(e)).collect::>(); + + crate::with_vcx(move |vcx| vcx.alloc_slice(&vec)) + } + + fn fold_slice_slice( + &mut self, + s: &'vir [&'vir [ExprGen<'vir, Cur, Next>]], + ) -> &'vir [&'vir [ExprGen<'vir, Cur, Next>]] { + let vec = s.iter().map(|e| self.fold_slice(e)).collect::>(); + + crate::with_vcx(move |vcx| vcx.alloc_slice(&vec)) + } + + fn fold_local(&mut self, local: crate::Local<'vir>) -> crate::ExprGen<'vir, Cur, Next> { + crate::with_vcx(move |vcx| vcx.mk_local_ex_local(local)) + } + + fn fold_field( + &mut self, + recv: crate::ExprGen<'vir, Cur, Next>, + field: crate::Field<'vir>, + ) -> crate::ExprGen<'vir, Cur, Next> { + let recv = self.fold(recv); + crate::with_vcx(move |vcx| vcx.mk_field_expr(recv, field)) + } + + fn fold_old( + &mut self, + expr: crate::ExprGen<'vir, Cur, Next>, + ) -> crate::ExprGen<'vir, Cur, Next> { + let expr = self.fold(expr); + + crate::with_vcx(move |vcx| vcx.mk_old_expr(expr)) + } + + fn fold_const(&mut self, value: crate::Const<'vir>) -> crate::ExprGen<'vir, Cur, Next> { + crate::with_vcx(move |vcx| { + vcx.alloc(ExprGenData { + kind: vcx.alloc(ExprKindGenData::Const(value)), + }) + }) + } + + fn fold_result(&mut self) -> crate::ExprGen<'vir, Cur, Next> { + crate::with_vcx(move |vcx| { + vcx.alloc(ExprGenData { + kind: vcx.alloc(ExprKindGenData::Result), + }) + }) + } + + fn fold_acc_field( + &mut self, + recv: ExprGen<'vir, Cur, Next>, + field: Field<'vir>, + perm: Option>, + ) -> crate::ExprGen<'vir, Cur, Next> { + let recv = self.fold(recv); + let perm = self.fold_option(perm); + + crate::with_vcx(move |vcx| vcx.mk_acc_field_expr(recv, field, perm)) + } + + fn fold_predicate_app( + &mut self, + target: &'vir str, // TODO: identifiers + args: &'vir [ExprGen<'vir, Cur, Next>], + perm: Option>, + ) -> crate::ExprGen<'vir, Cur, Next> { + let args = self.fold_slice(args); + let perm = self.fold_option(perm); + + crate::with_vcx(move |vcx| { + let pred_app = vcx.alloc(PredicateAppGenData { target, args, perm }); + + vcx.mk_predicate_app_expr(pred_app) + }) + } + + fn fold_unfolding( + &mut self, + PredicateAppGenData { target, args, perm }: PredicateAppGen<'vir, Cur, Next>, + expr: ExprGen<'vir, Cur, Next>, + ) -> crate::ExprGen<'vir, Cur, Next> { + let expr = self.fold(expr); + + let args = self.fold_slice(args); + let perm = self.fold_option(*perm); + + crate::with_vcx(move |vcx| { + let target = vcx.alloc(PredicateAppGenData { target, args, perm }); + vcx.mk_unfolding_expr(target, expr) + }) + } + + fn fold_unop( + &mut self, + kind: UnOpKind, + expr: ExprGen<'vir, Cur, Next>, + ) -> crate::ExprGen<'vir, Cur, Next> { + let expr = self.fold(expr); + crate::with_vcx(move |vcx| vcx.mk_unary_op_expr(kind, expr)) + } + + fn fold_binop( + &mut self, + kind: BinOpKind, + lhs: ExprGen<'vir, Cur, Next>, + rhs: ExprGen<'vir, Cur, Next>, + ) -> crate::ExprGen<'vir, Cur, Next> { + let lhs = self.fold(lhs); + let rhs = self.fold(rhs); + + crate::with_vcx(move |vcx| vcx.mk_bin_op_expr(kind, lhs, rhs)) + } + + fn fold_ternary( + &mut self, + cond: ExprGen<'vir, Cur, Next>, + then: ExprGen<'vir, Cur, Next>, + else_: ExprGen<'vir, Cur, Next>, + ) -> crate::ExprGen<'vir, Cur, Next> { + let cond = self.fold(cond); + let then = self.fold(then); + let else_ = self.fold(else_); + + crate::with_vcx(move |vcx| vcx.mk_ternary_expr(cond, then, else_)) + } + + fn fold_forall( + &mut self, + qvars: &'vir [LocalDecl<'vir>], + triggers: &'vir [&'vir [ExprGen<'vir, Cur, Next>]], + body: ExprGen<'vir, Cur, Next>, + ) -> crate::ExprGen<'vir, Cur, Next> { + let triggers = self.fold_slice_slice(triggers); + let body = self.fold(body); + + crate::with_vcx(move |vcx| vcx.mk_forall_expr(qvars, triggers, body)) + } + + fn fold_let( + &mut self, + name: &'vir str, + val: ExprGen<'vir, Cur, Next>, + expr: ExprGen<'vir, Cur, Next>, + ) -> crate::ExprGen<'vir, Cur, Next> { + let val = self.fold(val); + let expr = self.fold(expr); + + crate::with_vcx(move |vcx| vcx.mk_let_expr(name, val, expr)) + } + + fn fold_func_app( + &mut self, + target: &'vir str, + src_args: &'vir [ExprGen<'vir, Cur, Next>], + result_ty: Option>, + ) -> crate::ExprGen<'vir, Cur, Next> { + let src_args = self.fold_slice(src_args); + + crate::with_vcx(move |vcx| vcx.mk_func_app(target, src_args, result_ty)) + } + + fn fold_todo(&mut self, msg: &'vir str) -> crate::ExprGen<'vir, Cur, Next> { + crate::with_vcx(move |vcx| vcx.mk_todo_expr(msg)) + } + + fn fold_lazy( + &mut self, + name: &'vir str, + func: &'vir Box Fn(&'vir VirCtxt<'a>, Cur) -> Next + 'vir>, + ) -> crate::ExprGen<'vir, Cur, Next> { + crate::with_vcx(move |vcx| { + vcx.mk_lazy_expr( + name, + Box::new(move |ctx, c| { + let r = func(ctx, c); + // TODO + r + }), + ) + }) + } +} diff --git a/vir/src/gendata.rs b/vir/src/gendata.rs index babaaaa5991..77261d0a02f 100644 --- a/vir/src/gendata.rs +++ b/vir/src/gendata.rs @@ -89,6 +89,7 @@ pub struct ExprGenData<'vir, Curr: 'vir, Next: 'vir>{ pub kind: ExprKindGen<'vir, Curr, Next> } + pub enum ExprKindGenData<'vir, Curr: 'vir, Next: 'vir> { Local(Local<'vir>), Field(ExprGen<'vir, Curr, Next>, Field<'vir>), // TODO: FieldApp? @@ -157,6 +158,19 @@ pub struct FunctionGenData<'vir, Curr, Next> { pub(crate) expr: Option>, } +impl<'vir, Curr, Next> crate::Optimizable for FunctionGenData<'vir, Curr, Next> { + fn optimize(&self) -> Self { + FunctionGenData { + name: self. name, + args: self.args, + ret: self.ret, + expr: self.expr.optimize(), + pres: self.pres.optimize(), + posts: self.posts.optimize(), + } + } +} + // TODO: why is this called "pure"? #[derive(Reify)] pub struct PureAssignGenData<'vir, Curr, Next> { @@ -225,6 +239,20 @@ pub struct MethodGenData<'vir, Curr, Next> { pub(crate) blocks: Option<&'vir [CfgBlockGen<'vir, Curr, Next>]>, // first one is the entrypoint } + +impl<'vir, Curr, Next> crate::Optimizable for MethodGenData<'vir, Curr, Next> { + fn optimize(&self) -> Self { + MethodGenData { + name: self. name, + args: self.args, + rets: self.rets, + blocks: self.blocks, + pres: self.pres.optimize(), + posts: self.posts.optimize(), + } + } +} + #[derive(Debug, Reify)] pub struct ProgramGenData<'vir, Curr, Next> { #[reify_copy] pub(crate) fields: &'vir [Field<'vir>], diff --git a/vir/src/lib.rs b/vir/src/lib.rs index d2afbd55406..06ce383c898 100644 --- a/vir/src/lib.rs +++ b/vir/src/lib.rs @@ -11,14 +11,18 @@ mod macros; mod refs; mod reify; mod callable_idents; +mod folder; +mod opt; +pub use callable_idents::*; pub use context::*; pub use data::*; +pub use folder::*; pub use gendata::*; pub use genrefs::*; +pub use opt::*; pub use refs::*; pub use reify::*; -pub use callable_idents::*; // for all arena-allocated types, there are two type definitions: one with // a `Data` suffix, containing the actual data; and one without the suffix, diff --git a/vir/src/opt.rs b/vir/src/opt.rs new file mode 100644 index 00000000000..5e5ee877981 --- /dev/null +++ b/vir/src/opt.rs @@ -0,0 +1,257 @@ +use std::collections::HashMap; + +use crate::{ExprFolder, ExprGen, ExprGenData}; + +pub trait Optimizable: Sized { + fn optimize(&self) -> Self; +} + +impl<'vir, T> Optimizable for Option<&'vir T> +where + T: Optimizable, +{ + fn optimize(&self) -> Self { + self.map(|inner| { + let o = inner.optimize(); + crate::with_vcx(move |vcx| vcx.alloc(o)) + }) + } +} + +impl<'vir, T> Optimizable for &'vir [&T] +where + T: Optimizable, +{ + fn optimize(&self) -> Self { + let v = self + .iter() + .map(|e| { + let e = e.optimize(); + crate::with_vcx(|vcx| vcx.alloc(e)) + }) + .collect::>(); + crate::with_vcx(move |vcx| vcx.alloc_slice(&v)) + } +} + +impl<'vir, Curr, Next> crate::Optimizable for ExprGenData<'vir, Curr, Next> { + fn optimize(&self) -> Self { + let r = crate::with_vcx(move |vcx| vcx.alloc(ExprGenData { kind: self.kind })); + let s1 = (VariableOptimizerFolder { + rename: Default::default(), + }) + .fold(r); + + let s2 = EveryThingInliner::new().fold(s1); + let s3 = BoolOptimizerFolder.fold(s2); + + ExprGenData { kind: s3.kind } + } +} + +struct BoolOptimizerFolder; + +impl<'vir, Cur, Next> ExprFolder<'vir, Cur, Next> for BoolOptimizerFolder { + // transforms `a == true` into `a` and `a == false` into `!a` + fn fold_binop( + &mut self, + kind: crate::BinOpKind, + lhs: ExprGen<'vir, Cur, Next>, + rhs: ExprGen<'vir, Cur, Next>, + ) -> crate::ExprGen<'vir, Cur, Next> { + let lhs = self.fold(lhs); + let rhs = self.fold(rhs); + + if let crate::BinOpKind::CmpEq = kind { + if let crate::ExprKindGenData::Const(crate::ConstData::Bool(b)) = rhs.kind { + return if *b { + // case lhs == true + lhs + } else { + // case lhs == false + crate::with_vcx(move |vcx| vcx.mk_unary_op_expr(crate::UnOpKind::Not, lhs)) + }; + } + } + + crate::with_vcx(move |vcx| vcx.mk_bin_op_expr(kind, lhs, rhs)) + } + + // Transforms `c? true : false` into `c` + fn fold_ternary( + &mut self, + cond: ExprGen<'vir, Cur, Next>, + then: ExprGen<'vir, Cur, Next>, + else_: ExprGen<'vir, Cur, Next>, + ) -> crate::ExprGen<'vir, Cur, Next> { + let cond = self.fold(cond); + let then = self.fold(then); + let else_ = self.fold(else_); + + if let ( + crate::ExprKindGenData::Const(crate::ConstData::Bool(true)), + crate::ExprKindGenData::Const(crate::ConstData::Bool(false)), + ) = (then.kind, else_.kind) + { + return cond; + } + + crate::with_vcx(move |vcx| vcx.mk_ternary_expr(cond, then, else_)) + } +} + +pub(crate) struct EveryThingInliner<'vir, Cur, Next> { + rename: HashMap<&'vir str, crate::ExprGen<'vir, Cur, Next>>, +} + +impl<'vir, Cur, Next> EveryThingInliner<'vir, Cur, Next> { + fn new() -> Self { + Self { + rename: HashMap::new(), + } + } +} + +impl<'vir, Cur, Next> ExprFolder<'vir, Cur, Next> for EveryThingInliner<'vir, Cur, Next> { + fn fold_let( + &mut self, + name: &'vir str, + val: ExprGen<'vir, Cur, Next>, + expr: ExprGen<'vir, Cur, Next>, + ) -> crate::ExprGen<'vir, Cur, Next> { + let val = self.fold(val); + + self.rename.insert(name, val); + + let expr = self.fold(expr); + + expr + } + + fn fold_local(&mut self, local: crate::Local<'vir>) -> crate::ExprGen<'vir, Cur, Next> { + let lcl = crate::with_vcx(move |vcx| vcx.mk_local_ex_local(local)); + + self.rename.get(local.name).map(|e| *e).unwrap_or(lcl) + } + + // Transforms `C ? f(a) : f(b)` into `f(C? a : b)` + fn fold_ternary( + &mut self, + cond: ExprGen<'vir, Cur, Next>, + then: ExprGen<'vir, Cur, Next>, + else_: ExprGen<'vir, Cur, Next>, + ) -> crate::ExprGen<'vir, Cur, Next> { + let cond = self.fold(cond); + let then = self.fold(then); + let else_ = self.fold(else_); + + if let ( + crate::ExprKindGenData::FuncApp(then_app), + crate::ExprKindGenData::FuncApp(else_app), + ) = (then.kind, else_.kind) + { + if then_app.args.len() == 1 + && else_app.args.len() == 1 + && else_app.target == then_app.target + && else_app.result_ty == then_app.result_ty + { + return crate::with_vcx(move |vcx| { + vcx.mk_func_app( + then_app.target, + &[vcx.mk_ternary_expr(cond, then_app.args[0], else_app.args[0])], + then_app.result_ty, + ) + }); + } + } + + crate::with_vcx(move |vcx| vcx.mk_ternary_expr(cond, then, else_)) + } + + // transforms `foo_read_x(foo_cons(a_1, ... a_n))` into a_x + fn fold_func_app( + &mut self, + target: &'vir str, + src_args: &'vir [ExprGen<'vir, Cur, Next>], + result_ty: Option>, + ) -> crate::ExprGen<'vir, Cur, Next> { + let src_args = self.fold_slice(src_args); + + // Hacky way to do read of cons: + if src_args.len() == 1 { + if let crate::ExprKindGenData::FuncApp(innerfuncapp) = src_args[0].kind { + if let Some((start, "cons")) = innerfuncapp.target.rsplit_once("_") { + if let Some((_, read_nr)) = target.rsplit_once("_") { + if target.ends_with(&format!("read_{}", read_nr)) + && target.starts_with(start) + { + if let Ok(read_nr) = read_nr.parse::() { + return innerfuncapp.args[read_nr]; + } else { + println!("ERROR: Not a number: {} {}", innerfuncapp.target, target); + } + } + } + } + } + } + + crate::with_vcx(move |vcx| vcx.mk_func_app(target, src_args, result_ty)) + } +} + +pub(crate) struct VariableOptimizerFolder<'vir> { + rename: HashMap, +} + +impl<'vir, Cur, Next> ExprFolder<'vir, Cur, Next> for VariableOptimizerFolder<'vir> { + fn fold_local(&mut self, local: crate::Local<'vir>) -> crate::ExprGen<'vir, Cur, Next> { + let nam = self + .rename + .get(local.name) + .map(|e| *e) + .unwrap_or(local.name); + crate::with_vcx(move |vcx| vcx.mk_local_ex(&nam, local.ty)) + } + + fn fold_old( + &mut self, + expr: crate::ExprGen<'vir, Cur, Next>, + ) -> crate::ExprGen<'vir, Cur, Next> { + // Do not go inside of old + crate::with_vcx(|vcx| vcx.mk_old_expr(expr)) + } + + fn fold_let( + &mut self, + name: &'vir str, + val: ExprGen<'vir, Cur, Next>, + expr: ExprGen<'vir, Cur, Next>, + ) -> crate::ExprGen<'vir, Cur, Next> { + let val = self.fold(val); + + match val.kind { + // let name = loc.name + crate::ExprKindGenData::Local(loc) => { + let t = self + .rename + .get(loc.name) + .map(|e| e.to_owned()) + .unwrap_or(loc.name); + assert!(self.rename.insert(name.to_string(), t).is_none()); + return self.fold(expr); + } + _ => {} + } + + let expr = self.fold(expr); + + if let crate::ExprKindGenData::Local(inner_local) = expr.kind { + if inner_local.name == name { + // if we encounter the case `let X = val in X` then just return `val` + return val; + } + } + crate::with_vcx(move |vcx| vcx.mk_let_expr(name, val, expr)) + } +}