Skip to content
Draft
17 changes: 15 additions & 2 deletions prusti-encoder/src/encoders/const.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@ use vir::{CallableIdent, Arity};

pub struct ConstEnc;


#[derive(Clone)]
pub struct ConstEncOutput<'vir>(pub vir::Expr<'vir>);


impl<'vir> task_encoder::Optimizable for ConstEncOutput<'vir> {}

impl<'vir> From<vir::Expr<'vir>> for ConstEncOutput<'vir> {
fn from(value: vir::Expr<'vir>) -> Self {
Self(value)
}
}

#[derive(Clone, Debug)]
pub struct ConstEncOutputRef<'vir> {
pub base_name: String,
Expand All @@ -28,7 +41,7 @@ impl TaskEncoder for ConstEnc {
usize, // current encoding depth
DefId, // DefId of the current function
);
type OutputFullLocal<'vir> = vir::Expr<'vir>;
type OutputFullLocal<'vir> = ConstEncOutput<'vir>;
type EncodingError = ();

fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> Self::TaskKey<'vir> {
Expand Down Expand Up @@ -94,6 +107,6 @@ impl TaskEncoder for ConstEnc {
}),
mir::ConstantKind::Ty(_) => todo!(),
};
Ok((res, ()))
Ok((res.into(), ()))
}
}
3 changes: 3 additions & 0 deletions prusti-encoder/src/encoders/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ pub struct GenericEncOutput<'vir> {
pub domain_type: vir::Domain<'vir>,
}

impl<'vir> task_encoder::Optimizable for GenericEncOutput<'vir> {}


impl TaskEncoder for GenericEnc {
task_encoder::encoder_cache!(GenericEnc);

Expand Down
3 changes: 3 additions & 0 deletions prusti-encoder/src/encoders/local_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ pub struct MirLocalDefEncOutput<'vir> {
}
pub type MirLocalDefEncError = ();


impl<'vir> task_encoder::Optimizable for MirLocalDefEncOutput<'vir> {}

#[derive(Clone, Copy)]
pub struct LocalDef<'vir> {
pub local: vir::Local<'vir>,
Expand Down
3 changes: 3 additions & 0 deletions prusti-encoder/src/encoders/mir_builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ pub struct MirBuiltinEncOutput<'vir> {
pub function: vir::Function<'vir>,
}

impl<'vir> task_encoder::Optimizable for MirBuiltinEncOutput<'vir> {}


use crate::encoders::SnapshotEnc;

impl TaskEncoder for MirBuiltinEnc {
Expand Down
15 changes: 12 additions & 3 deletions prusti-encoder/src/encoders/mir_impure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use task_encoder::{
TaskEncoder,
TaskEncoderDependencies,
};
use vir::{MethodIdent, UnknownArity, CallableIdent};
use vir::{with_vcx, CallableIdent, MethodIdent, Optimizable, UnknownArity};

pub struct MirImpureEnc;

Expand All @@ -32,6 +32,15 @@ pub struct MirImpureEncOutput<'vir> {
pub method: vir::Method<'vir>,
}

impl<'vir> task_encoder::Optimizable for MirImpureEncOutput<'vir> {
fn optimize(self) -> Self {
let method = self.method.optimize();
let method = with_vcx(|vcx| vcx.alloc(method));
MirImpureEncOutput { method }
}
}


use crate::encoders::{PredicateEnc, ConstEnc, MirBuiltinEnc, MirFunctionEnc, MirLocalDefEnc, MirSpecEnc};

const ENCODE_REACH_BB: bool = false;
Expand Down Expand Up @@ -391,7 +400,7 @@ impl<'tcx, 'vir, 'enc> EncVisitor<'tcx, 'vir, 'enc> {
ty_out.ref_to_snap.apply(self.vcx, [self.encode_place(Place::from(source))])
}
mir::Operand::Constant(box constant) =>
self.deps.require_local::<ConstEnc>((constant.literal, 0, self.def_id)).unwrap()
self.deps.require_local::<ConstEnc>((constant.literal, 0, self.def_id)).unwrap().0
}
}

Expand All @@ -409,7 +418,7 @@ impl<'tcx, 'vir, 'enc> EncVisitor<'tcx, 'vir, 'enc> {
}
mir::Operand::Constant(box constant) => {
let ty_out = self.deps.require_ref::<PredicateEnc>(ty).unwrap();
let constant = self.deps.require_local::<ConstEnc>((constant.literal, 0, self.def_id)).unwrap();
let constant = self.deps.require_local::<ConstEnc>((constant.literal, 0, self.def_id)).unwrap().0;
(constant, ty_out)
}
};
Expand Down
6 changes: 5 additions & 1 deletion prusti-encoder/src/encoders/mir_pure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ pub struct MirPureEncOutput<'vir> {
pub expr: ExprRet<'vir>,
}

impl<'vir> task_encoder::Optimizable for MirPureEncOutput<'vir> {}



#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PureKind {
Closure,
Expand Down Expand Up @@ -624,7 +628,7 @@ impl<'tcx, 'vir: 'enc, 'enc> Enc<'tcx, 'vir, 'enc>
mir::Operand::Copy(place)
| mir::Operand::Move(place) => self.encode_place(curr_ver, place),
mir::Operand::Constant(box constant) =>
self.deps.require_local::<ConstEnc>((constant.literal, self.encoding_depth, self.def_id)).unwrap().lift(),
self.deps.require_local::<ConstEnc>((constant.literal, self.encoding_depth, self.def_id)).unwrap().0.lift(),
}
}

Expand Down
11 changes: 10 additions & 1 deletion prusti-encoder/src/encoders/mir_pure_function.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use prusti_rustc_interface::{middle::{mir, ty}, span::def_id::DefId};

use task_encoder::{TaskEncoder, TaskEncoderDependencies};
use vir::{Reify, FunctionIdent, UnknownArity, CallableIdent};
use vir::{CallableIdent, FunctionIdent, Optimizable, Reify, UnknownArity};

use crate::encoders::{
MirPureEnc, MirPureEncTask, mir_pure::PureKind, MirSpecEnc, MirLocalDefEnc,
Expand All @@ -28,6 +28,15 @@ pub struct MirFunctionEncOutput<'vir> {
pub function: vir::Function<'vir>,
}

impl<'vir> task_encoder::Optimizable for MirFunctionEncOutput<'vir> {
fn optimize(self) -> Self {
let function = self.function.optimize();
let function = vir::with_vcx(|vcx| vcx.alloc(function));

MirFunctionEncOutput { function }
}
}

impl TaskEncoder for MirFunctionEnc {
task_encoder::encoder_cache!(MirFunctionEnc);

Expand Down
3 changes: 3 additions & 0 deletions prusti-encoder/src/encoders/pure/spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ pub struct MirSpecEncOutput<'vir> {
pub post_args: &'vir [vir::Expr<'vir>],
}

impl<'vir> task_encoder::Optimizable for MirSpecEncOutput<'vir> {}


impl TaskEncoder for MirSpecEnc {
task_encoder::encoder_cache!(MirSpecEnc);

Expand Down
5 changes: 3 additions & 2 deletions prusti-encoder/src/encoders/spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ use prusti_rustc_interface::{
};
use prusti_interface::specs::typed::{DefSpecificationMap, ProcedureSpecification};
use task_encoder::{
TaskEncoder,
TaskEncoderDependencies,
Optimizable, TaskEncoder, TaskEncoderDependencies
};

pub struct SpecEnc;
Expand All @@ -19,6 +18,8 @@ pub struct SpecEncOutput<'vir> {
pub posts: &'vir [DefId],
}

impl<'vir> Optimizable for SpecEncOutput<'vir> {}

use std::cell::RefCell;
thread_local! {
static DEF_SPEC_MAP: RefCell<Option<DefSpecificationMap>> = RefCell::new(Default::default());
Expand Down
24 changes: 18 additions & 6 deletions prusti-encoder/src/encoders/type/domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,21 @@ pub struct DomainEncOutputRef<'vir> {
}
impl<'vir> task_encoder::OutputRefAny for DomainEncOutputRef<'vir> {}


#[derive(Clone)]
pub struct DomainEncOutput<'vir>(pub vir::Domain<'vir>);

impl<'vir> task_encoder::Optimizable for DomainEncOutput<'vir> {}

impl<'vir> From<vir::Domain<'vir>> for DomainEncOutput<'vir> {
fn from(value: vir::Domain<'vir>) -> Self {
DomainEncOutput(value)
}
}

use crate::encoders::SnapshotEnc;

pub fn all_outputs<'vir>() -> Vec<vir::Domain<'vir>> {
pub fn all_outputs<'vir>() -> Vec<DomainEncOutput<'vir>> {
DomainEnc::all_outputs()
}

Expand All @@ -90,7 +102,7 @@ impl TaskEncoder for DomainEnc {

type OutputRef<'vir> = DomainEncOutputRef<'vir>;
type OutputFullDependency<'vir> = DomainEncSpecifics<'vir>;
type OutputFullLocal<'vir> = vir::Domain<'vir>;
type OutputFullLocal<'vir> = DomainEncOutput<'vir>;
//type OutputFullDependency<'vir> = DomainEncOutputDep<'vir>;

type EncodingError = ();
Expand All @@ -109,7 +121,7 @@ impl TaskEncoder for DomainEnc {
Self::EncodingError,
Option<Self::OutputFullDependency<'vir>>,
)> {
vir::with_vcx(|vcx| match task_key.kind() {
(vir::with_vcx(|vcx| match task_key.kind() {
TyKind::Bool | TyKind::Char | TyKind::Int(_) | TyKind::Uint(_) | TyKind::Float(_) => {
let (base_name, prim_type) = match task_key.kind() {
TyKind::Bool => (String::from("Bool"), &vir::TypeData::Bool),
Expand Down Expand Up @@ -197,7 +209,7 @@ impl TaskEncoder for DomainEnc {
Ok((enc.finalize(), specifics))
}
kind => todo!("{kind:?}"),
})
}))
}
}

Expand Down Expand Up @@ -531,13 +543,13 @@ impl<'vir, 'tcx> DomainEncData<'vir, 'tcx> {
domain: self.domain,
}
}
fn finalize(self) -> vir::Domain<'vir> {
fn finalize(self) -> DomainEncOutput<'vir> {
self.vcx.mk_domain(
self.domain.name(),
self.domain.arity().args(),
self.vcx.alloc_slice(&self.axioms),
self.vcx.alloc_slice(&self.functions),
)
).into()
}
}

Expand Down
3 changes: 3 additions & 0 deletions prusti-encoder/src/encoders/type/predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ pub struct PredicateEncOutput<'vir> {
pub method_assign: vir::Method<'vir>,
}

impl<'vir> task_encoder::Optimizable for PredicateEncOutput<'vir> {}


use super::{snapshot::SnapshotEnc, domain::{DomainDataPrim, DomainDataStruct, DomainDataEnum, DiscrBounds}};

impl TaskEncoder for PredicateEnc {
Expand Down
3 changes: 3 additions & 0 deletions prusti-encoder/src/encoders/type/snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ pub struct SnapshotEncOutput<'vir> {
pub specifics: DomainEncSpecifics<'vir>,
}


impl<'vir> task_encoder::Optimizable for SnapshotEncOutput<'vir> {}

use super::domain::{DomainEnc, DomainEncSpecifics};

impl TaskEncoder for SnapshotEnc {
Expand Down
3 changes: 3 additions & 0 deletions prusti-encoder/src/encoders/type/viper_tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ pub struct ViperTupleEncOutput<'vir> {
tuple: Option<DomainDataStruct<'vir>>,
}

impl<'vir> task_encoder::Optimizable for ViperTupleEncOutput<'vir> {}


impl<'vir> ViperTupleEncOutput<'vir> {
pub fn mk_cons<'tcx, Curr, Next>(
&self,
Expand Down
34 changes: 27 additions & 7 deletions prusti-encoder/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,26 @@ use prusti_rustc_interface::{
hir,
};


const ENABLE_OPTIMIZATION : bool = true;

// Wrapper Trait for task_encoder::Optimizable to allow toggling of optimization
// TODO: replace with config
trait MaybeOptimize {
fn optimize(self) -> Self;
}

impl<T> MaybeOptimize for T where T : task_encoder::Optimizable {
fn optimize(self) -> Self {
if ENABLE_OPTIMIZATION {
task_encoder::Optimizable::optimize(self)
}
else {
self
}
}
}

pub fn test_entrypoint<'tcx>(
tcx: ty::TyCtxt<'tcx>,
body: EnvBody<'tcx>,
Expand Down Expand Up @@ -63,34 +83,34 @@ pub fn test_entrypoint<'tcx>(
let mut viper_code = String::new();

header(&mut viper_code, "methods");
for output in crate::encoders::MirImpureEnc::all_outputs() {
for output in crate::encoders::MirImpureEnc::all_outputs().optimize() {
viper_code.push_str(&format!("{:?}\n", output.method));
}

header(&mut viper_code, "functions");
for output in crate::encoders::MirFunctionEnc::all_outputs() {
for output in crate::encoders::MirFunctionEnc::all_outputs().optimize() {
viper_code.push_str(&format!("{:?}\n", output.function));
}

header(&mut viper_code, "MIR builtins");
for output in crate::encoders::MirBuiltinEnc::all_outputs() {
for output in crate::encoders::MirBuiltinEnc::all_outputs().optimize() {
viper_code.push_str(&format!("{:?}\n", output.function));
}

header(&mut viper_code, "generics");
for output in crate::encoders::GenericEnc::all_outputs() {
for output in crate::encoders::GenericEnc::all_outputs().optimize() {
viper_code.push_str(&format!("{:?}\n", output.snapshot_param));
viper_code.push_str(&format!("{:?}\n", output.predicate_param));
viper_code.push_str(&format!("{:?}\n", output.domain_type));
}

header(&mut viper_code, "snapshots");
for output in crate::encoders::DomainEnc_all_outputs() {
viper_code.push_str(&format!("{:?}\n", output));
for output in crate::encoders::DomainEnc_all_outputs().optimize() {
viper_code.push_str(&format!("{:?}\n", output.0));
}

header(&mut viper_code, "types");
for output in crate::encoders::PredicateEnc::all_outputs() {
for output in crate::encoders::PredicateEnc::all_outputs().optimize() {
for field in output.fields {
viper_code.push_str(&format!("{:?}", field));
}
Expand Down
14 changes: 13 additions & 1 deletion task-encoder/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@ use std::cell::RefCell;
pub trait OutputRefAny {}
impl OutputRefAny for () {}


pub trait Optimizable: Sized {
fn optimize(self) -> Self {
self
}
}

impl<T> Optimizable for Vec<T> where T: Optimizable {
fn optimize(self) -> Self {
self.into_iter().map(|e|e.optimize()).collect()
}
}
pub enum TaskEncoderCacheState<'vir, E: TaskEncoder + 'vir + ?Sized> {
// None, // indicated by absence in the cache

Expand Down Expand Up @@ -177,7 +189,7 @@ pub trait TaskEncoder {
/// Fully encoded output for this task. When encoding items which can be
/// dependencies (such as methods), this output should only be emitted in
/// one Viper program.
type OutputFullLocal<'vir>: Clone;
type OutputFullLocal<'vir>: Clone + Optimizable;

/// Fully encoded output for this task for dependents. When encoding items
/// which can be dependencies (such as methods), this output should be
Expand Down
5 changes: 3 additions & 2 deletions vir/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ pub enum ConstData {
Null,
}

#[derive(PartialEq, Eq)]
pub enum TypeData<'vir> {
Int {
bit_width: u8,
Expand All @@ -102,12 +103,12 @@ pub enum TypeData<'vir> {
Unsupported(UnsupportedType<'vir>)
}

#[derive(Clone)]
#[derive(Clone, PartialEq, Eq)]
pub struct UnsupportedType<'vir> {
pub name: &'vir str,
}

#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct DomainParamData<'vir> {
pub name: &'vir str, // TODO: identifiers
}
Expand Down
Loading