diff --git a/prusti-encoder/src/encoder_traits/pure_function_enc.rs b/prusti-encoder/src/encoder_traits/pure_function_enc.rs index b1b5582dee9..0850a427bce 100644 --- a/prusti-encoder/src/encoder_traits/pure_function_enc.rs +++ b/prusti-encoder/src/encoder_traits/pure_function_enc.rs @@ -139,8 +139,7 @@ where param_env: vcx.tcx().param_env(def_id), substs, caller_def_id, - })? - .expr; + })?; let expr = expr.reify(vcx, (def_id, spec.pre_args)); assert!( expr.ty() == return_type.snapshot, diff --git a/prusti-encoder/src/encoders/const.rs b/prusti-encoder/src/encoders/const.rs index a0c47e1445f..8e22049053d 100644 --- a/prusti-encoder/src/encoders/const.rs +++ b/prusti-encoder/src/encoders/const.rs @@ -107,8 +107,7 @@ impl TaskEncoder for ConstEnc { kind: PureKind::Constant(uneval.promoted.unwrap()), caller_def_id: Some(def_id), }; - let expr = deps.require_local::(task)?.expr; - use vir::Reify; + let expr = deps.require_local::(task)?; Ok(expr.reify(vcx, (uneval.def, &[]))) })?, mir::Const::Ty(_, _) => todo!("ConstantKind::Ty"), diff --git a/prusti-encoder/src/encoders/local_def.rs b/prusti-encoder/src/encoders/local_def.rs index 0120900be2b..a4802ac6082 100644 --- a/prusti-encoder/src/encoders/local_def.rs +++ b/prusti-encoder/src/encoders/local_def.rs @@ -73,8 +73,13 @@ impl TaskEncoder for MirLocalDefEnc { } } + let trusted = crate::encoders::with_proc_spec(def_id, |def_spec| { + def_spec.trusted.extract_inherit().unwrap_or_default() + }) + .unwrap_or_default(); vir::with_vcx(|vcx| { - let data = if let Some(local_def_id) = def_id.as_local() { + let local_def_id = def_id.as_local().filter(|_| !trusted); + let data = if let Some(local_def_id) = local_def_id { let body = vcx .body_mut() .get_impure_fn_body(local_def_id, substs, caller_def_id); diff --git a/prusti-encoder/src/encoders/mir_pure.rs b/prusti-encoder/src/encoders/mir_pure.rs index a0eeeb85f1e..f3b0514efd3 100644 --- a/prusti-encoder/src/encoders/mir_pure.rs +++ b/prusti-encoder/src/encoders/mir_pure.rs @@ -53,7 +53,14 @@ type ExprRet<'vir> = vir::ExprGen<'vir, ExprInput<'vir>, vir::ExprKind<'vir>>; pub struct MirPureEncOutput<'vir> { // TODO: is this a good place for argument types? //pub arg_tys: &'vir [Type<'vir>], - pub expr: ExprRet<'vir>, + expr: ExprRet<'vir>, +} + +impl<'vir> MirPureEncOutput<'vir> { + pub fn reify(self, vcx: &'vir vir::VirCtxt<'vir>, input: ExprInput<'vir>) -> vir::Expr<'vir> { + use vir::{Reify, Optimizable}; + self.expr.reify(vcx, input)//.optimize() + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] @@ -1060,7 +1067,6 @@ impl<'vir: 'enc, 'enc> Enc<'vir, 'enc> { caller_def_id: Some(self.def_id), }) .unwrap() - .expr // arguments to the closure are // - the closure itself // - the qvars diff --git a/prusti-encoder/src/encoders/pure/spec.rs b/prusti-encoder/src/encoders/pure/spec.rs index 77620c44cb9..4487bf3dd7f 100644 --- a/prusti-encoder/src/encoders/pure/spec.rs +++ b/prusti-encoder/src/encoders/pure/spec.rs @@ -100,8 +100,7 @@ impl TaskEncoder for MirSpecEnc { caller_def_id: Some(def_id), }, ) - .unwrap() - .expr; + .unwrap(); let expr = expr.reify(vcx, (*spec_def_id, pre_args)); let span = vcx.tcx().def_span(spec_def_id); vcx.with_span(span, |vcx| to_bool.apply(vcx, [expr])) @@ -142,8 +141,7 @@ impl TaskEncoder for MirSpecEnc { caller_def_id: Some(def_id), }, ) - .unwrap() - .expr; + .unwrap(); let expr = expr.reify(vcx, (*spec_def_id, post_args)); to_bool.apply(vcx, [expr]) }) @@ -177,7 +175,6 @@ impl TaskEncoder for MirSpecEnc { }, ) .unwrap() - .expr }); let rhs_expr = deps .require_local::( @@ -191,8 +188,7 @@ impl TaskEncoder for MirSpecEnc { caller_def_id: Some(def_id), }, ) - .unwrap() - .expr; + .unwrap(); let lhs_expr = lhs_expr .map(|lhs_expr| lhs_expr.reify(vcx, (lhs_def_id.unwrap(), pledge_args))); let rhs_expr = rhs_expr.reify(vcx, (*rhs_def_id, pledge_args)); diff --git a/vir/src/fold.rs b/vir/src/fold.rs new file mode 100644 index 00000000000..89a3d863ff7 --- /dev/null +++ b/vir/src/fold.rs @@ -0,0 +1,261 @@ +use crate::{debug_info::DebugInfo, *}; + +pub trait ExprFolder<'vir, Curr, Next>: Sized { + fn super_fold(&mut self, e: ExprGen<'vir, Curr, Next>) -> ExprGen<'vir, Curr, Next> { + match e.kind { + ExprKindGenData::Local(local) => self.fold_local(local), + ExprKindGenData::Field(recv, field) => self.fold_field(recv, field), + ExprKindGenData::Old(expr) => self.fold_old(expr), + ExprKindGenData::Const(value) => self.fold_const(value, e.debug_info, e.span), + ExprKindGenData::Result(ty) => self.fold_result(ty, e.debug_info, e.span), + ExprKindGenData::AccField(AccFieldGenData { recv, field, perm }) => { + self.fold_acc_field(recv, field, *perm) + } + ExprKindGenData::Unfolding(UnfoldingGenData { target, expr }) => { + self.fold_unfolding(target, expr) + } + ExprKindGenData::UnOp(UnOpGenData { kind, expr }) => self.fold_unop(*kind, expr), + ExprKindGenData::BinOp(BinOpGenData { kind, lhs, rhs }) => { + self.fold_binop(*kind, lhs, rhs) + } + ExprKindGenData::Ternary(TernaryGenData { cond, then, else_ }) => { + self.fold_ternary(cond, then, else_) + } + ExprKindGenData::Forall(ForallGenData { + qvars, + triggers, + body, + }) => self.fold_forall(qvars, triggers, body), + ExprKindGenData::Let(LetGenData { name, val, expr }) => self.fold_let(name, val, expr), + ExprKindGenData::FuncApp(FuncAppGenData { + target, + args, + result_ty, + }) => self.fold_func_app(target, args, *result_ty), + ExprKindGenData::PredicateApp(PredicateAppGenData { target, args, perm }) => { + self.fold_predicate_app(target, args, *perm) + } + ExprKindGenData::Lazy(lazy) => self.fold_lazy(lazy), + ExprKindGenData::Todo(msg) => self.fold_todo(msg), + ExprKindGenData::Exists(..) => todo!(), + ExprKindGenData::Wand(..) => todo!(), + } + } + + fn fold(&mut self, e: ExprGen<'vir, Curr, Next>) -> ExprGen<'vir, Curr, Next> { + self.super_fold(e) + } + + fn fold_option( + &mut self, + e: Option>, + ) -> Option> { + e.map(|i| self.fold(i)) + } + + fn fold_slice( + &mut self, + s: &'vir [ExprGen<'vir, Curr, Next>], + ) -> &'vir [ExprGen<'vir, Curr, Next>] { + let vec = s.iter().map(|e| self.fold(e)).collect::>(); + + with_vcx(move |vcx| vcx.alloc_slice(&vec)) + } + + fn fold_slice_slice( + &mut self, + s: &'vir [TriggerGen<'vir, Curr, Next>], + ) -> &'vir [TriggerGen<'vir, Curr, Next>] { + with_vcx(move |vcx| { + let vec = s + .iter() + .map(|e| vcx.mk_trigger(self.fold_slice(e.exprs))) + .collect::>(); + vcx.alloc_slice(&vec) + }) + } + + fn fold_local(&mut self, local: Local<'vir>) -> ExprGen<'vir, Curr, Next> { + with_vcx(move |vcx| vcx.mk_local_ex_local(local)) + } + + fn fold_field( + &mut self, + recv: ExprGen<'vir, Curr, Next>, + field: Field<'vir>, + ) -> ExprGen<'vir, Curr, Next> { + let recv = self.fold(recv); + with_vcx(move |vcx| vcx.mk_field_expr(recv, field)) + } + + fn fold_old(&mut self, expr: OldGen<'vir, Curr, Next>) -> ExprGen<'vir, Curr, Next> { + let expr = self.fold(expr.expr); + + with_vcx(move |vcx| vcx.mk_old_expr(expr)) + } + + fn fold_const( + &mut self, + value: Const<'vir>, + debug_info: DebugInfo<'vir>, + span: Option<&'vir VirSpan<'vir>>, + ) -> ExprGen<'vir, Curr, Next> { + with_vcx(move |vcx| { + vcx.alloc(ExprGenData { + kind: vcx.alloc(ExprKindGenData::Const(value)), + debug_info, + span, + }) + }) + } + + fn fold_result( + &mut self, + ty: Type<'vir>, + debug_info: DebugInfo<'vir>, + span: Option<&'vir VirSpan<'vir>>, + ) -> ExprGen<'vir, Curr, Next> { + with_vcx(move |vcx| { + vcx.alloc(ExprGenData { + kind: vcx.alloc(ExprKindGenData::Result(ty)), + debug_info, + span, + }) + }) + } + + fn fold_acc_field( + &mut self, + recv: ExprGen<'vir, Curr, Next>, + field: Field<'vir>, + perm: Option>, + ) -> ExprGen<'vir, Curr, Next> { + let recv = self.fold(recv); + let perm = self.fold_option(perm); + + with_vcx(move |vcx| vcx.mk_acc_field_expr(recv, field, perm)) + } + + fn fold_predicate_app( + &mut self, + target: &'vir str, // TODO: identifiers + args: &'vir [ExprGen<'vir, Curr, Next>], + perm: Option>, + ) -> ExprGen<'vir, Curr, Next> { + let args = self.fold_slice(args); + let perm = self.fold_option(perm); + + with_vcx(move |vcx| { + let pred_app = vcx.alloc(PredicateAppGenData { target, args, perm }); + + vcx.mk_predicate_app_expr(pred_app) + }) + } + + fn fold_unfolding( + &mut self, + pred: PredicateAppGen<'vir, Curr, Next>, + expr: ExprGen<'vir, Curr, Next>, + ) -> ExprGen<'vir, Curr, Next> { + let expr = self.fold(expr); + + let args = self.fold_slice(pred.args); + let perm = self.fold_option(pred.perm); + + with_vcx(move |vcx| { + let target = vcx.alloc(PredicateAppGenData { + target: pred.target, + args, + perm, + }); + vcx.mk_unfolding_expr(target, expr) + }) + } + + fn fold_unop( + &mut self, + kind: UnOpKind, + expr: ExprGen<'vir, Curr, Next>, + ) -> ExprGen<'vir, Curr, Next> { + let expr = self.fold(expr); + with_vcx(move |vcx| vcx.mk_unary_op_expr(kind, expr)) + } + + fn fold_binop( + &mut self, + kind: BinOpKind, + lhs: ExprGen<'vir, Curr, Next>, + rhs: ExprGen<'vir, Curr, Next>, + ) -> ExprGen<'vir, Curr, Next> { + let lhs = self.fold(lhs); + let rhs = self.fold(rhs); + + with_vcx(move |vcx| vcx.mk_bin_op_expr(kind, lhs, rhs)) + } + + fn fold_ternary( + &mut self, + cond: ExprGen<'vir, Curr, Next>, + then: ExprGen<'vir, Curr, Next>, + else_: ExprGen<'vir, Curr, Next>, + ) -> ExprGen<'vir, Curr, Next> { + let cond = self.fold(cond); + let then = self.fold(then); + let else_ = self.fold(else_); + + with_vcx(move |vcx| vcx.mk_ternary_expr(cond, then, else_)) + } + + fn fold_forall( + &mut self, + qvars: &'vir [LocalDecl<'vir>], + triggers: &'vir [TriggerGen<'vir, Curr, Next>], + body: ExprGen<'vir, Curr, Next>, + ) -> ExprGen<'vir, Curr, Next> { + let triggers = self.fold_slice_slice(triggers); + let body = self.fold(body); + + with_vcx(move |vcx| vcx.mk_forall_expr(qvars, triggers, body)) + } + + fn fold_let( + &mut self, + name: &'vir str, + val: ExprGen<'vir, Curr, Next>, + expr: ExprGen<'vir, Curr, Next>, + ) -> ExprGen<'vir, Curr, Next> { + let val = self.fold(val); + let expr = self.fold(expr); + + with_vcx(move |vcx| vcx.mk_let_expr(name, val, expr)) + } + + fn fold_func_app( + &mut self, + target: &'vir str, + src_args: &'vir [ExprGen<'vir, Curr, Next>], + result_ty: Type<'vir>, + ) -> ExprGen<'vir, Curr, Next> { + let src_args = self.fold_slice(src_args); + + with_vcx(move |vcx| vcx.mk_func_app(target, src_args, result_ty)) + } + + fn fold_todo(&mut self, msg: &'vir str) -> ExprGen<'vir, Curr, Next> { + with_vcx(move |vcx| vcx.mk_todo_expr(msg)) + } + + fn fold_lazy(&mut self, lazy: LazyGen<'vir, Curr, Next>) -> ExprGen<'vir, Curr, Next> { + with_vcx(move |vcx| { + vcx.mk_lazy_expr( + lazy.name, + lazy.ty, + Box::new(move |ctx, c| { + let r = (lazy.func)(ctx, c); + // TODO + r + }), + ) + }) + } +} diff --git a/vir/src/lib.rs b/vir/src/lib.rs index d435c2a4378..58407779b28 100644 --- a/vir/src/lib.rs +++ b/vir/src/lib.rs @@ -17,6 +17,8 @@ mod serde; mod spans; mod callable_idents; mod viper_ident; +mod fold; +mod opt; pub use callable_idents::*; pub use context::*; @@ -27,6 +29,8 @@ pub use refs::*; pub use reify::*; pub use spans::VirSpan; pub use viper_ident::*; +pub use fold::*; +pub use opt::*; // for all arena-allocated types, there are two type definitions: one with // a `Data` suffix, containing the actual data; and one without the suffix, diff --git a/vir/src/make.rs b/vir/src/make.rs index 62629a46563..8a9e1d42cd4 100644 --- a/vir/src/make.rs +++ b/vir/src/make.rs @@ -103,6 +103,7 @@ cfg_if! { e: ExprGen<'vir, Curr, Next> ) { match e.kind { + ExprKindGenData::Result(..) => (), ExprKindGenData::Local(LocalData { name, ty, debug_info }) => { if let Some(bound_ty) = m.get(name) { if !matches!(bound_ty, TypeData::Unsupported(_)) && diff --git a/vir/src/opt.rs b/vir/src/opt.rs new file mode 100644 index 00000000000..a1a87024308 --- /dev/null +++ b/vir/src/opt.rs @@ -0,0 +1,276 @@ +use std::collections::HashMap; + +use crate::*; + +pub trait Optimizable: Sized { + fn optimize(&self) -> Self; +} + +impl<'vir, T> Optimizable for Option<&'vir T> +where + T: Optimizable, +{ + fn optimize(&self) -> Self { + self.map(|inner| { + let o = inner.optimize(); + with_vcx(move |vcx| vcx.alloc(o)) + }) + } +} + +impl<'vir, T> Optimizable for &'vir [&T] +where + T: Optimizable, +{ + fn optimize(&self) -> Self { + let v = self + .iter() + .map(|e| { + let e = e.optimize(); + with_vcx(|vcx| vcx.alloc(e)) + }) + .collect::>(); + with_vcx(move |vcx| vcx.alloc_slice(&v)) + } +} + +impl<'vir, Curr, Next> Optimizable for &'vir ExprGenData<'vir, Curr, Next> { + fn optimize(&self) -> Self { + let r = *self; + let s1 = (VariableOptimizerFolder { + rename: Default::default(), + }) + .fold(r); + + let s2 = EveryThingInliner::new().fold(s1); + BoolOptimizerFolder.fold(s2) + } +} + +struct BoolOptimizerFolder; + +impl<'vir, Cur, Next> ExprFolder<'vir, Cur, Next> for BoolOptimizerFolder { + // transforms `a == true` into `a` and `a == false` into `!a` + fn fold_binop( + &mut self, + kind: BinOpKind, + lhs: ExprGen<'vir, Cur, Next>, + rhs: ExprGen<'vir, Cur, Next>, + ) -> ExprGen<'vir, Cur, Next> { + let lhs = self.fold(lhs); + let rhs = self.fold(rhs); + + if let BinOpKind::CmpEq = kind { + if let ExprKindGenData::Const(ConstData::Bool(b)) = rhs.kind { + return if *b { + // case lhs == true + lhs + } else { + // case lhs == false + with_vcx(move |vcx| vcx.mk_unary_op_expr(UnOpKind::Not, lhs)) + }; + } + } + + with_vcx(move |vcx| vcx.mk_bin_op_expr(kind, lhs, rhs)) + } + + // Transforms `c? true : false` into `c` + fn fold_ternary( + &mut self, + cond: ExprGen<'vir, Cur, Next>, + then: ExprGen<'vir, Cur, Next>, + else_: ExprGen<'vir, Cur, Next>, + ) -> ExprGen<'vir, Cur, Next> { + let cond = self.fold(cond); + let then = self.fold(then); + let else_ = self.fold(else_); + + if let ( + ExprKindGenData::Const(ConstData::Bool(true)), + ExprKindGenData::Const(ConstData::Bool(false)), + ) = (then.kind, else_.kind) + { + return cond; + } + + with_vcx(move |vcx| vcx.mk_ternary_expr(cond, then, else_)) + } +} + +pub(crate) struct EveryThingInliner<'vir, Cur, Next> { + rename: HashMap<&'vir str, ExprGen<'vir, Cur, Next>>, +} + +impl<'vir, Cur, Next> EveryThingInliner<'vir, Cur, Next> { + fn new() -> Self { + Self { + rename: HashMap::new(), + } + } +} + +impl<'vir, Cur, Next> ExprFolder<'vir, Cur, Next> for EveryThingInliner<'vir, Cur, Next> { + fn fold_let( + &mut self, + name: &'vir str, + val: ExprGen<'vir, Cur, Next>, + expr: ExprGen<'vir, Cur, Next>, + ) -> ExprGen<'vir, Cur, Next> { + let val = self.fold(val); + + self.rename.insert(name, val); + + let expr = self.fold(expr); + + expr + } + + fn fold_local(&mut self, local: Local<'vir>) -> ExprGen<'vir, Cur, Next> { + let lcl = with_vcx(move |vcx| vcx.mk_local_ex_local(local)); + + self.rename.get(local.name).map(|e| *e).unwrap_or(lcl) + } + + // Transforms `C ? f(a) : f(b)` into `f(C? a : b)` + fn fold_ternary( + &mut self, + cond: ExprGen<'vir, Cur, Next>, + then: ExprGen<'vir, Cur, Next>, + else_: ExprGen<'vir, Cur, Next>, + ) -> ExprGen<'vir, Cur, Next> { + let cond = self.fold(cond); + let then = self.fold(then); + let else_ = self.fold(else_); + + if let (ExprKindGenData::FuncApp(then_app), ExprKindGenData::FuncApp(else_app)) = + (then.kind, else_.kind) + { + if then_app.args.len() == 1 + && else_app.args.len() == 1 + && else_app.target == then_app.target + && else_app.result_ty == then_app.result_ty + { + return with_vcx(move |vcx| { + vcx.mk_func_app( + then_app.target, + &[vcx.mk_ternary_expr(cond, then_app.args[0], else_app.args[0])], + then_app.result_ty, + ) + }); + } + } + + with_vcx(move |vcx| vcx.mk_ternary_expr(cond, then, else_)) + } + + // transforms `foo_read_x(foo_cons(a_1, ... a_n))` into a_x + fn fold_func_app( + &mut self, + target: &'vir str, + src_args: &'vir [ExprGen<'vir, Cur, Next>], + result_ty: Type<'vir>, + ) -> ExprGen<'vir, Cur, Next> { + let src_args = self.fold_slice(src_args); + let default = || with_vcx(move |vcx| vcx.mk_func_app(target, src_args, result_ty)); + + // Hacky way to do read of cons: + if src_args.len() != 1 { + return default(); + } + let ExprKindGenData::FuncApp(inner) = src_args[0].kind else { + return default(); + }; + if target.strip_prefix("make_generic_s_").is_some_and(|other| inner.target.strip_prefix("make_concrete_s_") == Some(other)) { + assert_eq!(inner.args.len(), 1); + return inner.args[0]; + } + if target.strip_prefix("make_concrete_s_").is_some_and(|other| inner.target.strip_prefix("make_generic_s_") == Some(other)) { + assert_eq!(inner.args.len(), 1); + return inner.args[0]; + } + + if target == "s_Ref_immutable_value" && inner.target == "s_Ref_immutable_cons" { + assert_eq!(inner.args.len(), 2); + return inner.args[1]; + } + let strip_both = |s: &'vir str, pre, post| { + s.strip_prefix(pre) + .and_then(move |s| s.strip_suffix(post)) + }; + if strip_both(target, "s_", "_value").is_some_and(|middle| + strip_both(inner.target, "s_", "_cons") == Some(middle)) { + assert_eq!(inner.args.len(), 1); + return inner.args[0]; + } + + // let Some((outer_lhs, read_nr)) = target.rsplit_once("_") else { + // return default(); + // }; + // let Some((start, "cons")) = inner.target.rsplit_once("_") else { + // return default(); + // }; + // if target.ends_with(&format!("read_{}", read_nr)) + // && target.starts_with(start) + // { + // if let Ok(read_nr) = read_nr.parse::() { + // return innerfuncapp.args[read_nr]; + // } else { + // println!("ERROR: Not a number: {} {}", innerfuncapp.target, target); + // } + // } + default() + } +} + +pub(crate) struct VariableOptimizerFolder<'vir> { + rename: HashMap, +} + +impl<'vir, Cur, Next> ExprFolder<'vir, Cur, Next> for VariableOptimizerFolder<'vir> { + fn fold_local(&mut self, local: Local<'vir>) -> ExprGen<'vir, Cur, Next> { + let nam = self + .rename + .get(local.name) + .map(|e| *e) + .unwrap_or(local.name); + with_vcx(move |vcx| vcx.mk_local_ex(&nam, local.ty)) + } + + fn fold_old(&mut self, expr: &'vir OldGenData<'vir, Cur, Next>) -> ExprGen<'vir, Cur, Next> { + expr.expr + } + + fn fold_let( + &mut self, + name: &'vir str, + val: ExprGen<'vir, Cur, Next>, + expr: ExprGen<'vir, Cur, Next>, + ) -> ExprGen<'vir, Cur, Next> { + let val = self.fold(val); + + match val.kind { + // let name = loc.name + ExprKindGenData::Local(loc) => { + let t = self + .rename + .get(loc.name) + .map(|e| e.to_owned()) + .unwrap_or(loc.name); + assert!(self.rename.insert(name.to_string(), t).is_none()); + return self.fold(expr); + } + _ => {} + } + + let expr = self.fold(expr); + + if let ExprKindGenData::Local(inner_local) = expr.kind { + if inner_local.name == name { + // if we encounter the case `let X = val in X` then just return `val` + return val; + } + } + with_vcx(move |vcx| vcx.mk_let_expr(name, val, expr)) + } +}