diff --git a/prusti-encoder/src/encoder_traits/pure_function_enc.rs b/prusti-encoder/src/encoder_traits/pure_function_enc.rs index b69685dcd2d..4ce6fc0acb2 100644 --- a/prusti-encoder/src/encoder_traits/pure_function_enc.rs +++ b/prusti-encoder/src/encoder_traits/pure_function_enc.rs @@ -13,7 +13,7 @@ use crate::encoders::{ domain::DomainEnc, lifted::{ func_def_ty_params::LiftedTyParamsEnc, - ty::{EncodeGenericsAsLifted, LiftedTy, LiftedTyEnc}, + ty::{EncodeGenericsAsLifted, LiftedTy, LiftedTyEnc, LiftedTyEncTask}, }, most_generic_ty::extract_type_params, GenericEnc, MirLocalDefEnc, MirPureEnc, MirPureEncTask, MirSpecEnc, PureKind, @@ -60,7 +60,7 @@ where ty: Ty<'vir>, ) -> Option> { let lifted_ty = deps - .require_local::>(ty) + .require_local::>(LiftedTyEncTask::Ty(ty)) .unwrap(); match lifted_ty { LiftedTy::Generic(generic) => { diff --git a/prusti-encoder/src/encoders/const.rs b/prusti-encoder/src/encoders/const.rs index a0c47e1445f..40dfc5bbcca 100644 --- a/prusti-encoder/src/encoders/const.rs +++ b/prusti-encoder/src/encoders/const.rs @@ -12,6 +12,22 @@ use prusti_rustc_interface::{ use task_encoder::{EncodeFullResult, TaskEncoder, TaskEncoderDependencies}; use vir::{Arity, CallableIdent}; +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum ConstEncTask<'vir> { + Mir { + const_: mir::Const<'vir>, + encoding_depth: usize, // current encoding depth + def_id: DefId, // DefId of the current function + }, +} + +/// Encodes constants into snapshot expressions. The evaluation of a constant +/// is assumed to be side-effect free, as enforced by the compiler. This encoder +/// handles two different kinds of constants: ones coming from the MIR and ones +/// coming from the type system. +/// +/// See "Representing constants" in the rustc dev guide for an overview: +/// https://rustc-dev-guide.rust-lang.org/mir/index.html#representing-constants pub struct ConstEnc; use crate::encoders::{mir_pure::PureKind, MirPureEnc, MirPureEncTask}; @@ -24,11 +40,7 @@ use super::{ impl TaskEncoder for ConstEnc { task_encoder::encoder_cache!(ConstEnc); - type TaskDescription<'vir> = ( - mir::Const<'vir>, - usize, // current encoding depth - DefId, // DefId of the current function - ); + type TaskDescription<'vir> = ConstEncTask<'vir>; type OutputFullLocal<'vir> = vir::Expr<'vir>; type EncodingError = (); @@ -41,78 +53,90 @@ impl TaskEncoder for ConstEnc { deps: &mut TaskEncoderDependencies<'vir, Self>, ) -> EncodeFullResult<'vir, Self> { deps.emit_output_ref(*task_key, ())?; - let (const_, encoding_depth, def_id) = *task_key; - let res = match const_ { - mir::Const::Val(val, ty) => { - let kind = deps - .require_local::(ty)? - .generic_snapshot - .specifics; - match val { - ConstValue::Scalar(Scalar::Int(int)) => { - let prim = kind.expect_primitive(); - let val = int.to_bits(int.size()); - let val = prim.expr_from_bits(ty, val); - vir::with_vcx(|vcx| prim.prim_to_snap.apply(vcx, [val])) - } - ConstValue::Scalar(Scalar::Ptr(ptr, _)) => vir::with_vcx(|vcx| { - match vcx.tcx().global_alloc(ptr.provenance.alloc_id()) { - GlobalAlloc::Function { .. } => todo!(), - GlobalAlloc::VTable(_, _) => todo!(), - GlobalAlloc::Static(_) => todo!(), - GlobalAlloc::Memory(_mem) => { - // If the `unwrap` ever panics we need a different way to get the inner type - // let inner_ty = ty.builtin_deref(true).map(|t| t.ty).unwrap_or(ty); - let _inner_ty = ty.builtin_deref(true).unwrap(); - todo!() + match *task_key { + ConstEncTask::Mir { const_, encoding_depth, def_id } => { + let res = match const_ { + mir::Const::Val(val, ty) => { + let kind = deps + .require_local::(ty)? + .generic_snapshot + .specifics; + match val { + ConstValue::Scalar(Scalar::Int(int)) => { + let prim = kind.expect_primitive(); + let val = int.to_bits(int.size()); + let val = prim.expr_from_bits(ty, val); + vir::with_vcx(|vcx| prim.prim_to_snap.apply(vcx, [val])) + } + ConstValue::Scalar(Scalar::Ptr(ptr, _)) => vir::with_vcx(|vcx| { + match vcx.tcx().global_alloc(ptr.provenance.alloc_id()) { + GlobalAlloc::Function { .. } => todo!(), + GlobalAlloc::VTable(_, _) => todo!(), + GlobalAlloc::Static(_) => todo!(), + GlobalAlloc::Memory(_mem) => { + // If the `unwrap` ever panics we need a different way to get the inner type + // let inner_ty = ty.builtin_deref(true).map(|t| t.ty).unwrap_or(ty); + let _inner_ty = ty.builtin_deref(true).unwrap(); + todo!() + } + } + }), + ConstValue::ZeroSized => { + let s = kind.expect_structlike(); + assert_eq!(s.field_snaps_to_snap.arity().args().len(), 0); + vir::with_vcx(|vcx| s.field_snaps_to_snap.apply(vcx, &[])) } + // Encode `&str` constants to an opaque domain. If we ever want to perform string reasoning + // we will need to revisit this encoding, but for the moment this allows assertions to avoid + // crashing Prusti. + ConstValue::Slice { .. } if ty.peel_refs().is_str() => { + let ref_ty = kind.expect_immref(); + let str_ty = ty.peel_refs(); + let str_snap = deps + .require_local::(str_ty)? + .generic_snapshot + .specifics + .expect_structlike(); + let cast = deps.require_local::>(str_ty)?; + vir::with_vcx(|vcx| { + // first, we create a string snapshot + let snap = str_snap.field_snaps_to_snap.apply(vcx, &[]); + // upcast it to a param + let snap = cast.cast_to_generic_if_necessary(vcx, snap); + // wrap it in a ref + ref_ty.prim_to_snap.apply(vcx, [vcx.mk_null(), snap]) + }) + } + ConstValue::Slice { .. } => todo!("ConstValue::Slice : {:?}", const_.ty()), + ConstValue::Indirect { .. } => todo!("ConstValue::Indirect"), } - }), - ConstValue::ZeroSized => { - let s = kind.expect_structlike(); - assert_eq!(s.field_snaps_to_snap.arity().args().len(), 0); - vir::with_vcx(|vcx| s.field_snaps_to_snap.apply(vcx, &[])) } - // Encode `&str` constants to an opaque domain. If we ever want to perform string reasoning - // we will need to revisit this encoding, but for the moment this allows assertions to avoid - // crashing Prusti. - ConstValue::Slice { .. } if ty.peel_refs().is_str() => { - let ref_ty = kind.expect_immref(); - let str_ty = ty.peel_refs(); - let str_snap = deps - .require_local::(str_ty)? + mir::Const::Unevaluated(uneval, _) => vir::with_vcx(|vcx| { + let task = MirPureEncTask { + encoding_depth: encoding_depth + 1, + parent_def_id: uneval.def, + param_env: vcx.tcx().param_env(uneval.def), + substs: ty::List::identity_for_item(vcx.tcx(), uneval.def), + kind: PureKind::Constant(uneval.promoted.unwrap()), + caller_def_id: Some(def_id), + }; + let expr = deps.require_local::(task)?.expr; + use vir::Reify; + Ok(expr.reify(vcx, (uneval.def, &[]))) + })?, + mir::Const::Ty(_, _) => vir::with_vcx(|vcx| { + deps + .require_local::(vcx.tcx().mk_ty_from_kind(ty::TyKind::Uint(ty::UintTy::Usize))) + .unwrap() .generic_snapshot .specifics - .expect_structlike(); - let cast = deps.require_local::>(str_ty)?; - vir::with_vcx(|vcx| { - // first, we create a string snapshot - let snap = str_snap.field_snaps_to_snap.apply(vcx, &[]); - // upcast it to a param - let snap = cast.cast_to_generic_if_necessary(vcx, snap); - // wrap it in a ref - ref_ty.prim_to_snap.apply(vcx, [vcx.mk_null(), snap]) - }) - } - ConstValue::Slice { .. } => todo!("ConstValue::Slice : {:?}", const_.ty()), - ConstValue::Indirect { .. } => todo!("ConstValue::Indirect"), - } - } - mir::Const::Unevaluated(uneval, _) => vir::with_vcx(|vcx| { - let task = MirPureEncTask { - encoding_depth: encoding_depth + 1, - parent_def_id: uneval.def, - param_env: vcx.tcx().param_env(uneval.def), - substs: ty::List::identity_for_item(vcx.tcx(), uneval.def), - kind: PureKind::Constant(uneval.promoted.unwrap()), - caller_def_id: Some(def_id), + .expect_primitive() + .prim_to_snap.apply(vcx, [vcx.mk_uint::<0>()]) // TODO + }), }; - let expr = deps.require_local::(task)?.expr; - use vir::Reify; - Ok(expr.reify(vcx, (uneval.def, &[]))) - })?, - mir::Const::Ty(_, _) => todo!("ConstantKind::Ty"), - }; - Ok((res, ())) + Ok((res, ())) + } + //_ => todo!(), + } } } diff --git a/prusti-encoder/src/encoders/generic.rs b/prusti-encoder/src/encoders/generic.rs index bb5ec38a5c5..54d5ece9f76 100644 --- a/prusti-encoder/src/encoders/generic.rs +++ b/prusti-encoder/src/encoders/generic.rs @@ -1,9 +1,28 @@ -use task_encoder::{EncodeFullResult, TaskEncoder, TaskEncoderDependencies}; +use prusti_rustc_interface::{ + middle::ty::{self, TyKind}, + span::symbol, +}; +use task_encoder::{EncodeFullError, EncodeFullResult, TaskEncoder, TaskEncoderDependencies}; use vir::{ BinaryArity, CallableIdent, DomainIdent, DomainParamData, FunctionIdent, KnownArityAny, NullaryArity, PredicateIdent, TypeData, UnaryArity, ViperIdent, }; +use super::rust_ty_predicates::RustTyPredicatesEnc; + +pub fn generic_enc_ref<'vir, E: TaskEncoder>( + deps: &mut TaskEncoderDependencies<'vir, E>, +) -> Result, EncodeFullError<'vir, E>> { + vir::with_vcx(|vcx| { + let ty_param = vcx.tcx().mk_ty_from_kind(TyKind::Param(ty::ParamTy { + index: 0u32, + name: symbol::Symbol::intern("T"), + })); + deps.require_ref::(ty_param) + })?; + deps.require_ref::(()) +} + pub struct GenericEnc; #[derive(Clone, Debug)] @@ -21,6 +40,8 @@ pub struct GenericEncOutputRef<'vir> { pub unreachable_to_snap: FunctionIdent<'vir, NullaryArity<'vir>>, // pub domain_type_name: DomainIdent<'vir, KnownArityAny<'vir, DomainParamData<'vir>, 0>>, pub domain_param_name: DomainIdent<'vir, KnownArityAny<'vir, DomainParamData<'vir>, 0>>, + pub const_type_function: vir::FunctionIdent<'vir, UnaryArity<'vir>>, + pub const_value_function: vir::FunctionIdent<'vir, UnaryArity<'vir>>, } impl<'vir> task_encoder::OutputRefAny for GenericEncOutputRef<'vir> {} @@ -78,6 +99,17 @@ impl TaskEncoder for GenericEnc { &TYP_DOMAIN, ); + let const_type_function = FunctionIdent::new( + ViperIdent::new("const_typ"), + UnaryArity::new(&[&SNAPSHOT_PARAM_DOMAIN]), + &TYP_DOMAIN, + ); + let const_value_function = FunctionIdent::new( + ViperIdent::new("const_val"), + UnaryArity::new(&[&TYP_DOMAIN]), + &SNAPSHOT_PARAM_DOMAIN, + ); + let output_ref = GenericEncOutputRef { type_snapshot: &TYP_DOMAIN, param_snapshot: &SNAPSHOT_PARAM_DOMAIN, @@ -87,17 +119,13 @@ impl TaskEncoder for GenericEnc { ref_to_snap, unreachable_to_snap, param_type_function, + const_type_function, + const_value_function, }; #[allow(clippy::unit_arg)] deps.emit_output_ref(*task_key, output_ref)?; - let typ = FunctionIdent::new( - ViperIdent::new("typ"), - UnaryArity::new(&[&SNAPSHOT_PARAM_DOMAIN]), - &TYP_DOMAIN, - ); - vir::with_vcx(|vcx| { let t = vcx.mk_local_ex("t", &TYP_DOMAIN); let ref_to_snap = vcx.mk_function( @@ -111,7 +139,7 @@ impl TaskEncoder for GenericEnc { ))]), vcx.alloc_slice(&[vcx.mk_bin_op_expr( vir::BinOpKind::CmpEq, - typ.apply(vcx, [vcx.mk_result(&SNAPSHOT_PARAM_DOMAIN)]), + param_type_function.apply(vcx, [vcx.mk_result(&SNAPSHOT_PARAM_DOMAIN)]), t, )]), None, @@ -125,7 +153,10 @@ impl TaskEncoder for GenericEnc { Ok(( GenericEncOutput { param_snapshot: vir::vir_domain! { vcx; domain s_Param { - function typ(s_Param): Type; + function param_type_function(s_Param): Type; + function const_type_function(s_Param): Type; + function const_value_function(Type): s_Param; + axiom_inverse(const_value_function, const_type_function, s_Param); } }, ref_to_pred: vir::vir_predicate! { vcx; predicate p_Param(self_p: Ref, t: Type) }, diff --git a/prusti-encoder/src/encoders/mir_builtin.rs b/prusti-encoder/src/encoders/mir_builtin.rs index a83a04795e2..e7cc7676bd7 100644 --- a/prusti-encoder/src/encoders/mir_builtin.rs +++ b/prusti-encoder/src/encoders/mir_builtin.rs @@ -13,6 +13,7 @@ pub enum MirBuiltinEncError { #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] #[allow(clippy::enum_variant_names)] pub enum MirBuiltinEncTask<'tcx> { + Len(ty::Ty<'tcx>), UnOp(ty::Ty<'tcx>, mir::UnOp, ty::Ty<'tcx>), BinOp(ty::Ty<'tcx>, mir::BinOp, ty::Ty<'tcx>, ty::Ty<'tcx>), CheckedBinOp(ty::Ty<'tcx>, mir::BinOp, ty::Ty<'tcx>, ty::Ty<'tcx>), @@ -53,22 +54,16 @@ impl TaskEncoder for MirBuiltinEnc { task_key: &Self::TaskKey<'vir>, deps: &mut TaskEncoderDependencies<'vir, Self>, ) -> EncodeFullResult<'vir, Self> { - vir::with_vcx(|vcx| match *task_key { + let function = vir::with_vcx(|vcx| match *task_key { + MirBuiltinEncTask::Len(e_ty) => Self::handle_len(vcx, deps, *task_key, e_ty), MirBuiltinEncTask::UnOp(res_ty, op, operand_ty) => { assert_eq!(res_ty, operand_ty); - let function = Self::handle_un_op(vcx, deps, *task_key, op, operand_ty)?; - Ok((MirBuiltinEncOutput { function }, ())) + Self::handle_un_op(vcx, deps, *task_key, op, operand_ty) } - MirBuiltinEncTask::BinOp(res_ty, op, l_ty, r_ty) => { - let function = Self::handle_bin_op(vcx, deps, *task_key, res_ty, op, l_ty, r_ty)?; - Ok((MirBuiltinEncOutput { function }, ())) - } - MirBuiltinEncTask::CheckedBinOp(res_ty, op, l_ty, r_ty) => { - let function = - Self::handle_checked_bin_op(vcx, deps, *task_key, res_ty, op, l_ty, r_ty)?; - Ok((MirBuiltinEncOutput { function }, ())) - } - }) + MirBuiltinEncTask::BinOp(res_ty, op, l_ty, r_ty) => Self::handle_bin_op(vcx, deps, *task_key, res_ty, op, l_ty, r_ty), + MirBuiltinEncTask::CheckedBinOp(res_ty, op, l_ty, r_ty) => Self::handle_checked_bin_op(vcx, deps, *task_key, res_ty, op, l_ty, r_ty), + })?; + Ok((MirBuiltinEncOutput { function }, ())) } } @@ -83,6 +78,44 @@ fn int_name(ty: ty::Ty<'_>) -> &'static str { } impl MirBuiltinEnc { + fn handle_len<'vir>( + vcx: &'vir vir::VirCtxt<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + key: ::TaskKey<'vir>, + arg_ty: ty::Ty<'vir>, + ) -> Result, EncodeFullError<'vir, Self>> { + let res_ty = vcx.tcx().mk_ty_from_kind(ty::TyKind::Uint(ty::UintTy::Usize)); + let arg_ty_enc = deps + .require_local::(arg_ty)? + .generic_snapshot; + let res_ty_enc = deps + .require_local::(res_ty)? + .generic_snapshot; + + let name = vir::vir_format_identifier!(vcx, "mir_len"); // TODO: name (Slice or Array) + let arity = UnknownArity::new(vcx.alloc_slice(&[arg_ty_enc.snapshot])); + let function = FunctionIdent::new(name, arity, res_ty_enc.snapshot); + deps.emit_output_ref(key, MirBuiltinEncOutputRef { function })?; + + let prim_arg_ty = arg_ty_enc.specifics.expect_array(); + let prim_res_ty = res_ty_enc.specifics.expect_primitive(); + let snap_arg = vcx.mk_local_ex("arg", arg_ty_enc.snapshot); + let prim_arg = prim_arg_ty.snap_to_prim.apply(vcx, [snap_arg]); + let val = prim_res_ty.prim_to_snap.apply( + vcx, + [vcx.mk_unary_op_expr(vir::UnOpKind::SeqLen, prim_arg)], + ); + + Ok(vcx.mk_function( + name.to_str(), + vcx.alloc_slice(&[vcx.mk_local_decl("arg", arg_ty_enc.snapshot)]), + res_ty_enc.snapshot, + &[], + &[], + Some(val), + )) + } + fn handle_un_op<'vir>( vcx: &'vir vir::VirCtxt<'vir>, deps: &mut TaskEncoderDependencies<'vir, Self>, diff --git a/prusti-encoder/src/encoders/mir_impure.rs b/prusti-encoder/src/encoders/mir_impure.rs index 4415803c49b..1e765e954e8 100644 --- a/prusti-encoder/src/encoders/mir_impure.rs +++ b/prusti-encoder/src/encoders/mir_impure.rs @@ -32,25 +32,21 @@ use crate::{ pure_func_app_enc::PureFuncAppEnc, }, encoders::{ - self, - lifted::{ + self, lifted::{ aggregate_cast::{AggregateSnapArgsCastEnc, AggregateSnapArgsCastEncTask}, casters::CastTypePure, func_app_ty_params::LiftedFuncAppTyParamsEnc, - }, - FunctionCallTaskDescription, MirBuiltinEnc, WandEnc, WandEncTask, + }, rust_ty_snapshots::RustTySnapshotsEnc, FunctionCallTaskDescription, GenericEnc, MirBuiltinEnc, WandEnc, WandEncTask }, }; use super::{ - lifted::{ + r#const::ConstEncTask, lifted::{ cast::{CastArgs, CastToEnc}, casters::CastTypeImpure, rust_ty_cast::RustTyCastersEnc, - ty::{EncodeGenericsAsLifted, LiftedTyEnc}, - }, - rust_ty_predicates::{RustTyPredicatesEnc, RustTyPredicatesEncOutputRef}, - ConstEnc, MirMonoImpureEnc, MirPolyImpureEnc, WandEncOutput, + ty::{EncodeGenericsAsLifted, LiftedTyEnc, LiftedTyEncTask}, + }, rust_ty_predicates::{RustTyPredicatesEnc, RustTyPredicatesEncOutputRef}, ConstEnc, MirMonoImpureEnc, MirPolyImpureEnc, WandEncOutput }; pub struct MirImpureEnc; @@ -509,8 +505,8 @@ impl<'vir, 'enc, E: TaskEncoder> ImpureEncVisitor<'vir, 'enc, E> { fn pcg_repack(&mut self, repack_op: &RepackOp<'vir>) { match repack_op { - RepackOp::Expand(place, _target, capability_kind) - | RepackOp::Collapse(place, _target, capability_kind) => { + RepackOp::Expand(place, target, capability_kind) + | RepackOp::Collapse(place, target, capability_kind) => { if matches!(capability_kind, CapabilityKind::Write) { // Collapsing an already exhaled place is a no-op // TODO: unless it's through a Ref I imagine? @@ -558,7 +554,30 @@ impl<'vir, 'enc, E: TaskEncoder> ImpureEncVisitor<'vir, 'enc, E> { .unwrap() .apply_cast_if_necessary(self.vcx, proj_app); }*/ - self.stmt(self.vcx.mk_unfold_stmt(predicate)); + match target.last_projection() { + Some((_, mir::PlaceElem::Index(index_local))) => { + let usize_ty_out = self + .deps + .require_local::(self.vcx.tcx().mk_ty_from_kind(ty::TyKind::Uint(ty::UintTy::Usize))) + .unwrap(); + let index_args = self.vcx.alloc_slice(&[usize_ty_out.generic_snapshot.specifics.expect_primitive().snap_to_prim.apply(self.vcx, [self.encode_operand_snap(&mir::Operand::Copy(index_local.into()))])] + .into_iter() + .chain(args.into_iter().copied()) + .collect::>()); + let unfold_index = place_ty_out + .generic_predicate + .expect_array() + .unfold_index; + self.stmt( + self.vcx.alloc(vir::StmtGenData::new( + self.vcx.alloc(unfold_index.apply(self.vcx, index_args)), + )), + ); + }, + _ => { + self.stmt(self.vcx.mk_unfold_stmt(predicate)); + } + } for (apply, _) in &casts { self.stmt(apply); } @@ -566,7 +585,30 @@ impl<'vir, 'enc, E: TaskEncoder> ImpureEncVisitor<'vir, 'enc, E> { for (_, undo) in &casts { self.stmt(undo); } - self.stmt(self.vcx.mk_fold_stmt(predicate)); + match target.last_projection() { + Some((_, mir::PlaceElem::Index(index_local))) => { + let usize_ty_out = self + .deps + .require_local::(self.vcx.tcx().mk_ty_from_kind(ty::TyKind::Uint(ty::UintTy::Usize))) + .unwrap(); + let index_args = self.vcx.alloc_slice(&[usize_ty_out.generic_snapshot.specifics.expect_primitive().snap_to_prim.apply(self.vcx, [self.encode_operand_snap(&mir::Operand::Copy(index_local.into()))])] + .into_iter() + .chain(args.into_iter().copied()) + .collect::>()); + let fold_index = place_ty_out + .generic_predicate + .expect_array() + .fold_index; + self.stmt( + self.vcx.alloc(vir::StmtGenData::new( + self.vcx.alloc(fold_index.apply(self.vcx, index_args)), + )), + ); + }, + _ => { + self.stmt(self.vcx.mk_fold_stmt(predicate)); + } + } } } RepackOp::Weaken(place, CapabilityKind::Exclusive, CapabilityKind::Write) => { @@ -648,6 +690,11 @@ impl<'vir, 'enc, E: TaskEncoder> ImpureEncVisitor<'vir, 'enc, E> { for elem in place.projection { if crossed_ref { use vir::Reify; + let index_expr = if let mir::ProjectionElem::Index(index) = elem { + Some(self.encode_place(mir::Place::from(index).into()).expr.lift()) + } else { + None + }; let (expr, _) = crate::encoders::mir_pure::encode_place_element( self.vcx, self.deps, @@ -655,6 +702,7 @@ impl<'vir, 'enc, E: TaskEncoder> ImpureEncVisitor<'vir, 'enc, E> { elem, result.lift(), None, + index_expr, ); result = expr.reify(self.vcx, (self.def_id, &[])); } else { @@ -717,7 +765,11 @@ impl<'vir, 'enc, E: TaskEncoder> ImpureEncVisitor<'vir, 'enc, E> { fn encode_constant(&mut self, constant: &mir::ConstOperand<'vir>) -> vir::Expr<'vir> { self.deps - .require_local::((constant.const_, 0, self.def_id)) + .require_local::(ConstEncTask::Mir { + const_: constant.const_, + encoding_depth: 0, + def_id: self.def_id, + }) .unwrap() } @@ -742,6 +794,21 @@ impl<'vir, 'enc, E: TaskEncoder> ImpureEncVisitor<'vir, 'enc, E> { place: &EncodePlaceResult<'vir>, ) -> Vec<(vir::Stmt<'vir>, vir::Stmt<'vir>)> { match place.ty.ty.kind() { + TyKind::Array(elem_ty, _) => { + // TODO: make place_casts take the actual place as argument, so + // that we don't have to fake this: + let elem = mir::ProjectionElem::Index(0usize.into()); + let proj_app = self.encode_place_element(place.ty, elem, place.expr); + self.deps + .require_local::>( + *elem_ty, + ) + .unwrap() + .cast_to_concrete_if_possible(self.vcx, proj_app) + .into_iter() + .map(|cs| (cs.apply_cast_stmt, cs.unapply_cast_stmt)) + .collect() + } TyKind::Adt(def, _) if def.is_box() => { let proj_app = self.encode_place_element(place.ty, mir::ProjectionElem::Deref, place.expr); @@ -847,13 +914,31 @@ impl<'vir, 'enc, E: TaskEncoder> ImpureEncVisitor<'vir, 'enc, E> { let projection_p = field_access[field_idx.as_usize()]; let instantiated_ty = self .deps - .require_local::>(place_ty.ty) + .require_local::>(LiftedTyEncTask::Ty(place_ty.ty)) .unwrap(); let proj_args = e_ty .generic_predicate .ref_to_args(self.vcx, instantiated_ty, expr); projection_p.apply(self.vcx, proj_args) } + mir::ProjectionElem::Index(v) => { + let e_ty = self + .deps + .require_ref::(place_ty.ty) + .unwrap(); + let index_access = e_ty + .generic_predicate + .expect_array() + .index_access; + let instantiated_ty = self + .deps + .require_local::>(LiftedTyEncTask::Ty(place_ty.ty)) + .unwrap(); + let proj_args = e_ty + .generic_predicate + .ref_to_args(self.vcx, instantiated_ty, expr); + index_access.apply(self.vcx, proj_args) + } // TODO: should all variants start at the same `Ref`? mir::ProjectionElem::Downcast(..) => expr, mir::ProjectionElem::Deref => { @@ -873,7 +958,7 @@ impl<'vir, 'enc, E: TaskEncoder> ImpureEncVisitor<'vir, 'enc, E> { let projection_p = field_access[0]; let instantiated_ty = self .deps - .require_local::>(place_ty.ty) + .require_local::>(LiftedTyEncTask::Ty(place_ty.ty)) .unwrap(); let proj_args = e_ty.generic_predicate @@ -884,7 +969,7 @@ impl<'vir, 'enc, E: TaskEncoder> ImpureEncVisitor<'vir, 'enc, E> { // TODO: unfold? function? use snapshot? let instantiated_ty = self .deps - .require_local::>(place_ty.ty) + .require_local::>(LiftedTyEncTask::Ty(place_ty.ty)) .unwrap(); let deref_args = e_ty.generic_predicate @@ -1094,9 +1179,19 @@ impl<'vir, 'enc, E: TaskEncoder> mir::visit::Visitor<'vir> for ImpureEncVisitor< //mir::Rvalue::Repeat(Operand<'vir>, Const<'vir>) => {} //mir::Rvalue::ThreadLocalRef(DefId) => {} //mir::Rvalue::AddressOf(Mutability, Place<'vir>) => {} - //mir::Rvalue::Len(Place<'vir>) => {} //mir::Rvalue::Cast(CastKind, Operand<'vir>, Ty<'vir>) => {} + mir::Rvalue::Len(place) => { + let place_ty = place.ty(self.local_decls, self.vcx.tcx()); + let place_expr = self.encode_place_snap(Place::from(*place)).1; + let unop_function = self.deps.require_ref::( + crate::encoders::MirBuiltinEncTask::Len( + place_ty.ty, + ), + ).unwrap().function; + unop_function.apply(self.vcx, &[place_expr]) + } + mir::Rvalue::BinaryOp(op, box (l, r)) => { let l_ty = l.ty(self.local_decls, self.vcx.tcx()); let r_ty = r.ty(self.local_decls, self.vcx.tcx()); @@ -1144,6 +1239,26 @@ impl<'vir, 'enc, E: TaskEncoder> mir::visit::Visitor<'vir> for ImpureEncVisitor< ))*/ } + mir::Rvalue::Aggregate( + box kind @ mir::AggregateKind::Array(elem_ty), + values, + ) => { + let generic_enc = self.deps.require_ref::(()).unwrap(); + let e_rvalue_ty = self.deps.require_ref::(rvalue_ty).unwrap(); + let prim = e_rvalue_ty.generic_predicate.expect_array(); + let ty_caster = self.deps.require_local::( + AggregateSnapArgsCastEncTask { + tys: std::iter::repeat(*elem_ty).take(values.len()).collect(), + aggregate_type: kind.into() + } + ).unwrap(); + let value_snaps = values.iter().map(|value| self.encode_operand_snap(value)).collect::>(); + let casted_values = ty_caster.apply_casts(self.vcx, value_snaps.into_iter()); + prim.snap_data.prim_to_snap.apply(self.vcx, [ + self.vcx.mk_seq_lit(self.vcx.alloc_slice(&casted_values), &generic_enc.param_snapshot), + ]) + } + mir::Rvalue::Aggregate( box kind @ (mir::AggregateKind::Adt(..) | mir::AggregateKind::Tuple), fields, diff --git a/prusti-encoder/src/encoders/mir_pure.rs b/prusti-encoder/src/encoders/mir_pure.rs index 60992414db3..e72001f7594 100644 --- a/prusti-encoder/src/encoders/mir_pure.rs +++ b/prusti-encoder/src/encoders/mir_pure.rs @@ -14,14 +14,11 @@ use task_encoder::{EncodeFullResult, TaskEncoder, TaskEncoderDependencies}; use vir::add_debug_note; // TODO: replace uses of `PredicateEnc` with `SnapshotEnc` use super::{ - lifted::{ + r#const::ConstEncTask, lifted::{ aggregate_cast::{AggregateSnapArgsCastEnc, AggregateSnapArgsCastEncTask}, casters::CastTypePure, rust_ty_cast::RustTyCastersEnc, - }, - rust_ty_predicates::RustTyPredicatesEnc, - rust_ty_snapshots::RustTySnapshotsEnc, - GenericEnc, + }, rust_ty_predicates::RustTyPredicatesEnc, rust_ty_snapshots::RustTySnapshotsEnc, GenericEnc }; use crate::{ encoder_traits::pure_func_app_enc::PureFuncAppEnc, @@ -535,6 +532,84 @@ impl<'vir: 'enc, 'enc> Enc<'vir, 'enc> { mir::TerminatorKind::Return => stmt_update, + // TODO: there is some code duplication between here and SwitchInt + mir::TerminatorKind::Assert { + cond, + expected, + target, + .. + } => { + // encode the condition operand + let cond_ty = cond.ty(self.body, self.vcx.tcx()); + assert_eq!(*cond_ty.kind(), ty::TyKind::Bool); + let cond_expr = self.encode_operand(&new_curr_ver, cond); + let cond_ty_out = self + .deps + .require_local::(cond_ty) + .unwrap() + .generic_snapshot + .specifics + .expect_primitive(); + + // if cond == expected: walk the rest of the CFG + let ok_update = self.encode_cfg(&new_curr_ver, *target, join_point); + + // find locals updated in the "ok" branch, which were also + // defined before the branch + let mut mod_locals = ok_update.versions.keys() + .filter(|local| new_curr_ver.contains_key(local)) + .copied() + .collect::>(); + mod_locals.sort(); + mod_locals.dedup(); + + // if cond != expected: update locals to "unreachable" + let mut fail_update = Update::new(); + for local in &mod_locals { + let ty = self.body.local_decls[*local].ty; + let unreachable = self + .deps + .require_ref::(ty) + .unwrap() + .generic_predicate + .unreachable_to_snap; + self.bump_version(&mut fail_update, *local, unreachable.apply(self.vcx, [])); + } + + // create a Viper tuple of the updated locals + let tuple_ref = self + .deps + .require_local::(mod_locals.len()) + .unwrap(); + let phi_expr = self.vcx.mk_ternary_expr( + self.vcx.mk_bin_op_expr( + vir::BinOpKind::CmpEq, + cond_ty_out.snap_to_prim.apply(self.vcx, [cond_expr]), + cond_ty_out.expr_from_bits(cond_ty, if *expected { 1 } else { 0 }).lift(), + ), + self.reify_branch(&tuple_ref, &mod_locals, &new_curr_ver, ok_update), + self.reify_branch(&tuple_ref, &mod_locals, &new_curr_ver, fail_update), + ); + + // assign tuple into a `phi` variable + let phi_idx = self.phi_ctr; + self.phi_ctr += 1; + let mut phi_update = Update::new(); + phi_update.binds.push(UpdateBind::Phi(phi_idx, phi_expr)); + + // update locals by destructuring `phi` variable + // TODO: maybe this is unnecessary, we could instead use tuple + // access directly instead of the locals going forward? + for (elem_idx, local) in mod_locals.iter().enumerate() { + let ty = self.get_ty_for_local(*local); + let expr = self.mk_phi_acc(*local, tuple_ref.clone(), phi_idx, elem_idx, ty); + self.bump_version(&mut phi_update, *local, expr); + new_curr_ver.insert(*local, phi_update.versions[local]); + } + + stmt_update.merge(phi_update) + } + mir::TerminatorKind::Unreachable => { // update the return place to "unreachable" let mut end_update = Update::new(); @@ -688,7 +763,17 @@ impl<'vir: 'enc, 'enc> Enc<'vir, 'enc> { } // ThreadLocalRef // AddressOf - // Len + mir::Rvalue::Len(place) => { + let place_ty = place.ty(self.body, self.vcx.tcx()); + let unop_function = self + .deps + .require_ref::(crate::encoders::MirBuiltinEncTask::Len( + place_ty.ty + )) + .unwrap() + .function; + unop_function.apply(self.vcx, &[self.encode_place(curr_ver, place)]) + } // Cast mir::Rvalue::BinaryOp(op, box (l, r)) => { let l_ty = l.ty(self.body, self.vcx.tcx()); @@ -811,7 +896,11 @@ impl<'vir: 'enc, 'enc> Enc<'vir, 'enc> { } mir::Operand::Constant(box constant) => self .deps - .require_local::((constant.const_, self.encoding_depth, self.def_id)) + .require_local::(ConstEncTask::Mir { + const_: constant.const_, + encoding_depth: self.encoding_depth, + def_id: self.def_id, + }) .unwrap() .lift(), } @@ -856,7 +945,7 @@ impl<'vir: 'enc, 'enc> Enc<'vir, 'enc> { let mut place_ref = None; // TODO: factor this out (duplication with impure encoder)? for elem in place.projection { - (expr, place_ref) = self.encode_place_element(place_ty, elem, expr, place_ref); + (expr, place_ref) = self.encode_place_element(place_ty, elem, expr, place_ref, curr_ver); place_ty = place_ty.projection_ty(self.vcx.tcx(), elem); } // Can we ever have the use of a projected place? @@ -886,8 +975,14 @@ impl<'vir: 'enc, 'enc> Enc<'vir, 'enc> { elem: mir::PlaceElem<'vir>, expr: ExprRet<'vir>, place_ref: Option>, + curr_ver: &HashMap, ) -> (ExprRet<'vir>, Option>) { - encode_place_element(self.vcx, self.deps, place_ty, elem, expr, place_ref) + let index_expr = if let mir::PlaceElem::Index(index) = elem { + Some(self.encode_place_with_ref(curr_ver, &mir::Place::from(index)).0) + } else { + None + }; + encode_place_element(self.vcx, self.deps, place_ty, elem, expr, place_ref, index_expr) } fn encode_prusti_builtin( @@ -1281,6 +1376,7 @@ pub fn encode_place_element<'vir, 'enc, T: TaskEncoder>( elem: mir::PlaceElem<'vir>, expr: ExprRet<'vir>, place_ref: Option>, + index_expr: Option>, ) -> (ExprRet<'vir>, Option>) { match elem { mir::ProjectionElem::Deref => { @@ -1377,6 +1473,30 @@ pub fn encode_place_element<'vir, 'enc, T: TaskEncoder>( .map(|pr| struct_like.ref_to_field_refs[field_idx.as_usize()].apply(vcx, &[pr])); (proj_app, place_ref) } + mir::ProjectionElem::Index(..) => { + let ty::TyKind::Array(elem_ty, _) = place_ty.ty.kind() else { unreachable!("index but not array"); }; + let e_ty = deps + .require_ref::(place_ty.ty) + .unwrap(); + let array_enc = e_ty + .generic_predicate + .expect_array(); + let proj_app = array_enc.snap_data.snap_to_prim.apply(vcx, [expr]); + let usize_prim = deps + .require_local::(vcx.tcx().mk_ty_from_kind(ty::TyKind::Uint(ty::UintTy::Usize))) + .unwrap() + .generic_snapshot + .specifics + .expect_primitive(); + let proj_app = vcx.mk_bin_op_expr(vir::BinOpKind::SeqIndex, proj_app, usize_prim.snap_to_prim.apply(vcx, [index_expr.unwrap()])); + let proj_app = deps.require_local::>(*elem_ty) + .unwrap() + .cast_to_concrete_if_possible(vcx, proj_app); + let place_ref = place_ref + .map(|pr| array_enc.index_access.apply(vcx, &[pr])); + (proj_app, place_ref) + + } mir::ProjectionElem::Downcast(..) => (expr, place_ref), _ => todo!("Unsupported ProjectionElem {:?}", elem), } diff --git a/prusti-encoder/src/encoders/type/domain.rs b/prusti-encoder/src/encoders/type/domain.rs index 0036a89c951..36a4114ce85 100644 --- a/prusti-encoder/src/encoders/type/domain.rs +++ b/prusti-encoder/src/encoders/type/domain.rs @@ -2,16 +2,24 @@ // be an indirection in error storage somewhere, maybe even in `task-encoder`? #![allow(clippy::result_large_err)] -use prusti_rustc_interface::{ - middle::ty::{self, IntTy, ParamTy, TyKind, UintTy}, - span::symbol, - target::abi, -}; +use prusti_rustc_interface::middle::ty::{self, TyKind}; use task_encoder::{EncodeFullError, EncodeFullResult, TaskEncoder, TaskEncoderDependencies}; use vir::{ BinaryArity, CallableIdent, DomainAxiomData, DomainFunctionData, DomainIdent, DomainParamData, FunctionIdent, NullaryArityAny, ToKnownArity, UnaryArity, UnknownArity, }; +use super::{ + lifted::ty_constructor::TyConstructorEnc, most_generic_ty::{get_vir_base_name_kind, MostGenericTy}, rust_ty_snapshots::RustTySnapshotsEnc, +}; + +pub use super::kinds::{ + adt::DomainDataEnum, + array::DomainDataArray, + immref::DomainDataImmRef, + mutref::DomainDataMutRef, + primitive::DomainDataPrim, + structlike::DomainDataStruct, +}; /// You probably never want to use this, use `SnapshotEnc` instead. /// Note: there should never be a dependency on `PredicateEnc` inside this @@ -27,56 +35,6 @@ pub struct FieldFunctions<'vir> { pub write: FunctionIdent<'vir, BinaryArity<'vir>>, } -#[derive(Clone, Copy, Debug)] -pub struct DomainDataPrim<'vir> { - pub prim_type: vir::Type<'vir>, - /// Snapshot of self as argument. Returns Viper primitive value. - pub snap_to_prim: FunctionIdent<'vir, UnaryArity<'vir>>, - /// Viper primitive value as argument. Returns domain. - pub prim_to_snap: FunctionIdent<'vir, UnaryArity<'vir>>, -} -#[derive(Clone, Copy, Debug)] -pub struct DomainDataImmRef<'vir> { - /// Construct domain from a `Ref` value. - pub prim_to_snap: FunctionIdent<'vir, BinaryArity<'vir>>, - /// Function to access the referee. - pub deref_access: FunctionIdent<'vir, UnaryArity<'vir>>, - /// Function to access the snapshot value. - pub value_access: FunctionIdent<'vir, UnaryArity<'vir>>, -} -#[derive(Clone, Copy, Debug)] -pub struct DomainDataMutRef<'vir> { - /// Construct domain from a `Ref` value. - pub prim_to_snap: FunctionIdent<'vir, BinaryArity<'vir>>, - /// Function to access the referee. - pub deref_access: FunctionIdent<'vir, UnaryArity<'vir>>, - /// Function to access the snapshot value. - pub value_access: FunctionIdent<'vir, UnaryArity<'vir>>, -} -#[derive(Clone, Copy, Debug)] -pub struct DomainDataStruct<'vir> { - /// Construct domain from snapshots of fields or for primitive types - /// from the single Viper primitive value. - pub field_snaps_to_snap: FunctionIdent<'vir, UnknownArity<'vir>>, - /// Functions to access the fields. - pub field_access: &'vir [FieldFunctions<'vir>], -} -#[derive(Clone, Copy, Debug)] -pub struct DomainDataEnum<'vir> { - pub discr_ty: vir::Type<'vir>, - pub discr_prim: DomainDataPrim<'vir>, - //pub discr_bounds: DiscrBounds<'vir>, - pub snap_to_discr_snap: FunctionIdent<'vir, UnaryArity<'vir>>, - pub variants: &'vir [DomainDataVariant<'vir>], -} -#[derive(Clone, Copy, Debug)] -pub struct DomainDataVariant<'vir> { - pub name: symbol::Symbol, - pub vid: abi::VariantIdx, - pub discr: vir::Expr<'vir>, - pub fields: DomainDataStruct<'vir>, -} - #[derive(Clone, Copy, Debug)] pub enum DiscrBounds<'vir> { Range { @@ -91,6 +49,7 @@ pub enum DomainEncSpecifics<'vir> { Opaque, Param, Never, + Array(DomainDataArray<'vir>), Primitive(DomainDataPrim<'vir>), ImmRef(DomainDataImmRef<'vir>), MutRef(DomainDataMutRef<'vir>), @@ -123,15 +82,6 @@ impl<'vir> DomainEncOutputRef<'vir> { impl<'vir> task_encoder::OutputRefAny for DomainEncOutputRef<'vir> {} -use super::{ - lifted::{ - ty::{EncodeGenericsAsParamTy, LiftedTy, LiftedTyEnc}, - ty_constructor::TyConstructorEnc, - }, - most_generic_ty::{extract_type_params, get_vir_base_name_kind, MostGenericTy}, - rust_ty_snapshots::RustTySnapshotsEnc, -}; - pub fn all_outputs<'vir>() -> Vec> { DomainEnc::all_outputs().into_iter().flatten().collect() } @@ -187,6 +137,9 @@ impl TaskEncoder for DomainEnc { | TyKind::Float(_) => { super::kinds::primitive::domain(*task_key, deps, &mut builder)? } + TyKind::Array(..) => { + super::kinds::array::domain(*task_key, &output_ref, deps, &mut builder)? + } TyKind::Closure(..) => { super::kinds::closure::domain(*task_key, &output_ref, deps, &mut builder)? } @@ -215,7 +168,6 @@ impl TaskEncoder for DomainEnc { pub(crate) struct DomainBuilder<'vir> { pub(crate) vcx: &'vir vir::VirCtxt<'vir>, name: Option<&'vir str>, - generics: Option>>, domain_ident: Option>>>, self_type: Option>, axioms: Vec>, @@ -227,7 +179,6 @@ impl<'vir> DomainBuilder<'vir> { DomainBuilder { vcx, name: None, - generics: None, domain_ident: None, self_type: None, axioms: Vec::new(), @@ -245,10 +196,6 @@ impl<'vir> DomainBuilder<'vir> { ))); } - pub(crate) fn set_generics(&mut self, generics: Vec>) { - self.generics = Some(generics); - } - pub(crate) fn function( &mut self, name: &str, @@ -313,87 +260,6 @@ impl<'vir> DomainBuilder<'vir> { } } -// Utility functions - -impl<'vir> DomainEncSpecifics<'vir> { - #[track_caller] - pub fn expect_primitive(self) -> DomainDataPrim<'vir> { - match self { - Self::Primitive(data) => data, - _ => panic!("expected primitive"), - } - } - #[track_caller] - pub fn expect_immref(self) -> DomainDataImmRef<'vir> { - match self { - Self::ImmRef(data) => data, - _ => panic!("expected immref"), - } - } - #[track_caller] - pub fn expect_mutref(self) -> DomainDataMutRef<'vir> { - match self { - Self::MutRef(data) => data, - _ => panic!("expected mutref"), - } - } - #[track_caller] - pub fn expect_structlike(self) -> DomainDataStruct<'vir> { - match self { - Self::StructLike(data) => data, - _ => panic!("expected struct-like (was {self:?}"), - } - } - pub fn get_enumlike(self) -> Option>> { - match self { - Self::EnumLike(data) => Some(data), - _ => None, - } - } - #[track_caller] - pub fn expect_enumlike(self) -> Option> { - match self { - Self::EnumLike(data) => data, - _ => panic!("expected enum-like, was {self:?}"), - } - } -} -impl<'vir> DomainDataPrim<'vir> { - pub fn expr_from_bits(&self, ty: ty::Ty<'vir>, value: u128) -> vir::Expr<'vir> { - match *self.prim_type { - vir::TypeData::Bool => { - vir::with_vcx(|vcx| vcx.mk_const_expr(vir::ConstData::Bool(value != 0))) - } - vir::TypeData::Int => { - let (bit_width, signed) = match ty.kind() { - TyKind::Int(IntTy::Isize) => ((std::mem::size_of::() * 8) as u64, true), - TyKind::Int(ty) => (ty.bit_width().unwrap(), true), - TyKind::Uint(UintTy::Usize) => { - ((std::mem::size_of::() * 8) as u64, true) - } - TyKind::Uint(ty) => (ty.bit_width().unwrap(), false), - kind => unreachable!("{kind:?}"), - }; - let size = abi::Size::from_bits(bit_width); - let negative_value = if signed { - let value = size.sign_extend(value); - Some(value).filter(|value| value.is_negative()) - } else { - None - }; - match negative_value { - Some(value) => vir::with_vcx(|vcx| { - let value = vcx.mk_const_expr(vir::ConstData::Int(value.unsigned_abs())); - vcx.mk_unary_op_expr(vir::UnOpKind::Neg, value) - }), - None => vir::with_vcx(|vcx| vcx.mk_const_expr(vir::ConstData::Int(value))), - } - } - ref k => unreachable!("{k:?}"), - } - } -} - /// Data for encoding field access functions and axioms #[derive(Clone)] pub(super) struct FieldTy<'vir> { @@ -401,24 +267,10 @@ pub(super) struct FieldTy<'vir> { /// The type of encoded field pub(super) ty: vir::Type<'vir>, - - /// Information about the Rust type, only defined for fields that correspond - /// to actual Rust types. For example, this will be `None` for a Viper - /// `Bool` field encoded as part of the snapshot encoding of the rust bool - /// type. - pub(super) rust_ty_data: Option>, -} - -#[derive(Clone)] -pub(super) struct LiftedRustTyData<'vir> { - /// The representation of the Rust type of the field - lifted_ty: LiftedTy<'vir, ParamTy>, - /// Takes as input the value of the field, and returns its type - typeof_function: FunctionIdent<'vir, UnaryArity<'vir>>, } impl<'vir> FieldTy<'vir> { - pub fn mk_field_tys( + pub(super) fn mk_field_tys( vcx: &'vir vir::VirCtxt<'vir>, deps: &mut TaskEncoderDependencies<'vir, T>, variant: &ty::VariantDef, @@ -433,7 +285,7 @@ impl<'vir> FieldTy<'vir> { } pub(super) fn from_ty( - vcx: &'vir vir::VirCtxt<'vir>, + _vcx: &'vir vir::VirCtxt<'vir>, deps: &mut TaskEncoderDependencies<'vir, T>, ty: ty::Ty<'vir>, ) -> Result, EncodeFullError<'vir, T>> { @@ -441,17 +293,9 @@ impl<'vir> FieldTy<'vir> { .require_ref::(ty)? .generic_snapshot .snapshot; - let typeof_function = deps - .require_ref::(extract_type_params(vcx.tcx(), ty).0)? - .typeof_function; - let lifted_ty = deps.require_local::>(ty)?; Ok(FieldTy { rust_ty: ty, ty: vir_ty, - rust_ty_data: Some(LiftedRustTyData { - lifted_ty, - typeof_function, - }), }) } } diff --git a/prusti-encoder/src/encoders/type/kinds/adt.rs b/prusti-encoder/src/encoders/type/kinds/adt.rs index 30053756719..72ae44b8e89 100644 --- a/prusti-encoder/src/encoders/type/kinds/adt.rs +++ b/prusti-encoder/src/encoders/type/kinds/adt.rs @@ -1,21 +1,135 @@ use crate::encoders::{ domain::{ - DomainBuilder, DomainDataEnum, DomainDataStruct, DomainDataVariant, DomainEnc, + DomainBuilder, DomainDataStruct, DomainEnc, DomainEncOutputRef, DomainEncSpecifics, FieldTy, }, - lifted::ty::{EncodeGenericsAsParamTy, LiftedTyEnc}, + lifted::ty::{EncodeGenericsAsParamTy, LiftedTyEnc, LiftedTyEncTask}, predicate::{ - PredicateBuilder, PredicateEncData, PredicateEncDataEnum, PredicateEncDataStruct, - PredicateEncDataVariant, + PredicateBuilder, PredicateEncData, PredicateEncDataStruct, }, rust_ty_predicates::RustTyPredicatesEnc, rust_ty_snapshots::RustTySnapshotsEnc, snapshot::SnapshotEncOutput, - PredicateEnc, + PredicateEnc, PredicateEncOutputRef, +}; +use prusti_rustc_interface::{ + middle::ty, + span::symbol, + abi, }; -use prusti_rustc_interface::middle::ty; use task_encoder::{EncodeFullError, TaskEncoder, TaskEncoderDependencies}; -use vir::ToKnownArity; +use vir::{FunctionIdent, PredicateIdent, ToKnownArity, UnaryArity, UnknownArity}; + +use super::primitive::DomainDataPrim; + +#[derive(Clone, Copy, Debug)] +pub struct DomainDataEnum<'vir> { + pub discr_ty: vir::Type<'vir>, + pub discr_prim: DomainDataPrim<'vir>, + //pub discr_bounds: DiscrBounds<'vir>, + pub snap_to_discr_snap: FunctionIdent<'vir, UnaryArity<'vir>>, + pub variants: &'vir [DomainDataVariant<'vir>], +} + +#[derive(Clone, Copy, Debug)] +pub struct DomainDataVariant<'vir> { + pub name: symbol::Symbol, + pub vid: abi::VariantIdx, + pub discr: vir::Expr<'vir>, + pub fields: DomainDataStruct<'vir>, +} + +impl<'vir> DomainEncSpecifics<'vir> { + pub fn get_enumlike(self) -> Option>> { + match self { + Self::EnumLike(data) => Some(data), + _ => None, + } + } + + #[track_caller] + pub fn expect_enumlike(self) -> Option> { + match self { + Self::EnumLike(data) => data, + _ => panic!("expected enumlike domain data (got {self:?})"), + } + } +} + +#[derive(Clone, Copy, Debug)] +pub struct PredicateEncDataEnum<'vir> { + pub discr: FunctionIdent<'vir, UnaryArity<'vir>>, + pub discr_prim: DomainDataPrim<'vir>, + //pub discr_bounds: DiscrBounds<'vir>, + // pub snap_to_discr_snap: FunctionIdent<'vir, UnaryArity<'vir>>, + pub variants: &'vir [PredicateEncDataVariant<'vir>], +} + +#[derive(Clone, Copy, Debug)] +pub struct PredicateEncDataVariant<'vir> { + pub predicate: PredicateIdent<'vir, UnknownArity<'vir>>, + pub vid: abi::VariantIdx, + pub discr: vir::Expr<'vir>, + pub fields: PredicateEncDataStruct<'vir>, +} + +impl<'vir> PredicateEncOutputRef<'vir> { + pub fn get_enumlike(&self) -> Option<&Option>> { + match &self.specifics { + PredicateEncData::EnumLike(e) => Some(e), + _ => None, + } + } + + #[track_caller] + pub fn expect_enumlike(&self) -> Option<&PredicateEncDataEnum<'vir>> { + match &self.specifics { + PredicateEncData::EnumLike(data) => data.as_ref(), + s => panic!("expected enumlike predicate data (got {s:?})"), + } + } + + pub fn get_variant_any(&self, vid: abi::VariantIdx) -> &PredicateEncDataStruct<'vir> { + match &self.specifics { + PredicateEncData::StructLike(s) => { + assert_eq!(vid, abi::FIRST_VARIANT); + s + } + PredicateEncData::EnumLike(e) => &e.as_ref().unwrap().variants[vid.as_usize()].fields, + s => panic!("expected structlike or enumlike predicate data (got {s:?})"), + } + } + + #[track_caller] + pub fn expect_variant(&self, vid: abi::VariantIdx) -> &PredicateEncDataVariant<'vir> { + match &self.specifics { + PredicateEncData::EnumLike(e) => &e.as_ref().unwrap().variants[vid.as_usize()], + s => panic!("expected enumlike predicate data (got {s:?})"), + } + } + + #[track_caller] + pub fn expect_pred_variant_opt( + &self, + vid: Option, + ) -> PredicateIdent<'vir, UnknownArity<'vir>> { + vid.map(|vid| self.expect_variant(vid).predicate) + .unwrap_or(self.ref_to_pred) + } + + #[track_caller] + pub fn expect_variant_opt( + &self, + vid: Option, + ) -> &PredicateEncDataStruct<'vir> { + match vid { + None => self.expect_structlike(), + Some(vid) => { + &self.expect_enumlike().expect("empty enum").variants[vid.as_usize()].fields + } + } + } +} pub(crate) fn domain<'vir>( task_key: ::TaskKey<'vir>, @@ -33,7 +147,7 @@ pub(crate) fn domain<'vir>( .iter() .flat_map(ty::GenericArg::as_type) .map(|ty| { - deps.require_local::>(ty) + deps.require_local::>(LiftedTyEncTask::Ty(ty)) .unwrap() .expect_generic() }) diff --git a/prusti-encoder/src/encoders/type/kinds/array.rs b/prusti-encoder/src/encoders/type/kinds/array.rs new file mode 100644 index 00000000000..ef2be367c1d --- /dev/null +++ b/prusti-encoder/src/encoders/type/kinds/array.rs @@ -0,0 +1,222 @@ +use crate::encoders::{ + domain::{DomainBuilder, DomainEnc, DomainEncOutputRef, DomainEncSpecifics}, predicate::{PredicateBuilder, PredicateEncData}, rust_ty_snapshots::RustTySnapshotsEnc, snapshot::SnapshotEncOutput, GenericEnc, PredicateEnc, PredicateEncOutputRef +}; +use prusti_rustc_interface::middle::ty; +use task_encoder::{EncodeFullError, TaskEncoder, TaskEncoderDependencies}; +use vir::{FunctionIdent, MethodIdent, ToKnownArity, UnaryArity, UnknownArity}; + +#[derive(Clone, Copy, Debug)] +pub struct DomainDataArray<'vir> { + pub prim_type: vir::Type<'vir>, + /// Snapshot of self as argument. Returns Viper primitive value. + pub snap_to_prim: FunctionIdent<'vir, UnaryArity<'vir>>, + /// Viper primitive value as argument. Returns domain. + pub prim_to_snap: FunctionIdent<'vir, UnaryArity<'vir>>, +} + +impl<'vir> DomainEncSpecifics<'vir> { + #[track_caller] + pub fn expect_array(self) -> DomainDataArray<'vir> { + match self { + Self::Array(data) => data, + _ => panic!("expected array domain data (got {self:?})"), + } + } +} + +#[derive(Clone, Copy, Debug)] +pub struct PredicateEncDataArray<'vir> { + pub snap_data: DomainDataArray<'vir>, + pub index_access: FunctionIdent<'vir, UnknownArity<'vir>>, + pub unfold_index: MethodIdent<'vir, UnknownArity<'vir>>, + pub fold_index: MethodIdent<'vir, UnknownArity<'vir>>, +} + +impl<'vir> PredicateEncOutputRef<'vir> { + #[track_caller] + pub fn expect_array(&self) -> &PredicateEncDataArray<'vir> { + match &self.specifics { + PredicateEncData::Array(data) => data, + s => panic!("expected array predicate data (got {s:?})"), + } + } +} + +pub(crate) fn domain<'vir>( + task_key: ::TaskKey<'vir>, + _output_ref: &DomainEncOutputRef<'vir>, + deps: &mut TaskEncoderDependencies<'vir, DomainEnc>, + builder: &mut DomainBuilder<'vir>, +) -> Result, EncodeFullError<'vir, DomainEnc>> { + let ty = task_key.ty(); + let ty_kind = ty.kind(); + let ty::TyKind::Array(elem_ty, _) = ty_kind else { unreachable!() }; + let elem_ty_enc = deps.require_ref::(*elem_ty)?; + let prim_type = builder.vcx.mk_ty_seq(elem_ty_enc.generic_snapshot.snapshot); + + let value_ident = builder.function("value", &[builder.self_type()], prim_type); + let cons_ident = builder.function("cons", &[prim_type], builder.self_type()); + + builder.axiom("cons", vir::expr! { + forall s: [builder.self_type()] :: {[value_ident](s)} ([cons_ident]([value_ident](s))) == (s) + }); + builder.axiom("value", vir::expr! { + forall value: [prim_type] :: {[cons_ident](value)} ([value_ident]([cons_ident](value))) == (value) + }); + + Ok(DomainEncSpecifics::Array(DomainDataArray { + prim_type, + snap_to_prim: value_ident.to_known(), + prim_to_snap: cons_ident.to_known(), + })) +} + +pub(crate) fn predicate<'vir>( + _task_key: ::TaskKey<'vir>, + snap: SnapshotEncOutput<'vir>, + deps: &mut TaskEncoderDependencies<'vir, PredicateEnc>, + generic_decls: &[vir::LocalDecl<'vir>], + generic_exprs: &[vir::Expr<'vir>], + builder: &mut PredicateBuilder<'vir>, +) -> Result< + ( + PredicateEncData<'vir>, + Option, vir::ExprKind<'vir>>>, + ), + EncodeFullError<'vir, PredicateEnc>, +> { + // let ty = task_key.ty(); + // let ty_kind = ty.kind(); + + let snap_type = snap.snapshot; + let snap_data = snap.specifics.expect_array(); + let generic_enc = deps.require_ref::(())?; + + let ref_self = builder.vcx.mk_local("self", &vir::TypeData::Ref); + let ref_self_decl = builder.vcx.mk_local_decl_local(ref_self); + + // main predicate + let self_pred = builder.predicate( + "", + &[ref_self_decl] + .into_iter() + .chain(generic_decls.iter().cloned()) + .collect::>(), + None, + ); + + // Ref-to-snap + let (snap_ident, snap_func) = builder.mk_function( + "snap", + &[ref_self_decl] + .into_iter() + .chain(generic_decls.iter().cloned()) + .collect::>(), + snap_type, + &[vir::expr! { acc_wildcard([self_pred](ref_self, ..[generic_exprs])) }], + &[vir::expr! { forall i: Int :: { [builder.vcx.mk_bin_op_expr(vir::BinOpKind::SeqIndex, vir::expr! { [snap_data.snap_to_prim]([builder.vcx.mk_result(snap_type)]) }, vir::expr! { i })] } + (((0) <= (i)) && ((i) < (vpr_seq_len([snap_data.snap_to_prim]([builder.vcx.mk_result(snap_type)]))))) + ==> (([generic_enc.param_type_function]([builder.vcx.mk_bin_op_expr(vir::BinOpKind::SeqIndex, vir::expr! { [snap_data.snap_to_prim]([builder.vcx.mk_result(snap_type)]) }, vir::expr! { i })])) == ([generic_exprs[0]])) + }], + None, + ); + builder.function_snap = Some(snap_func); + + // "borrowed" predicate, to frame across index accesses + let borrowed_pred = builder.predicate( + "borrowed", + &[ref_self_decl] + .into_iter() + .chain(generic_decls.iter().cloned()) + .collect::>(), + None, + ); + let borrowed_snap = builder.function( + "borrowed_snap", + &[ref_self_decl] + .into_iter() + .chain(generic_decls.iter().cloned()) + .collect::>(), + snap_type, + &[vir::expr! { acc_wildcard([borrowed_pred](ref_self, ..[generic_exprs])) }], + &[], + None, + ); + + let index_access = builder.function( + "index", + &[ref_self_decl].into_iter() + .chain(generic_decls.iter().cloned()) + .collect::>(), + &vir::TypeData::Ref, + &[], // TODO: should have a read permission here! + &[], + None, + ); + + // unfold/fold index + let self_snap = vir::expr! { [snap_ident](ref_self, ..[generic_exprs]) }; + let self_val = vir::expr! { [snap_data.snap_to_prim](self_snap) }; + let index = builder.vcx.mk_local("index", &vir::TypeData::Int); + let index_decl = builder.vcx.mk_local_decl_local(index); + let index_val = builder.vcx.mk_bin_op_expr(vir::BinOpKind::SeqIndex, self_val, vir::expr! { index }); + + let unfold_index = builder.method( + "unfold_index", + &[index_decl, ref_self_decl] + .into_iter() + .chain(generic_decls.iter().cloned()) + .collect::>(), + &[], + &[ + vir::expr! { acc([self_pred](ref_self, ..[generic_exprs])) }, + vir::expr! { ((0) <= (index)) && ((index) < (vpr_seq_len(self_val))) }, + ], + &[ + vir::expr! { acc([borrowed_pred](ref_self, ..[generic_exprs])) }, + vir::expr! { acc([generic_enc.ref_to_pred]([index_access](ref_self, ..[generic_exprs]), ..[generic_exprs])) }, + vir::expr! { ([borrowed_snap](ref_self, ..[generic_exprs])) == (old(self_snap)) }, + vir::expr! { ([generic_enc.ref_to_snap]([index_access](ref_self, ..[generic_exprs]), ..[generic_exprs])) == (old(index_val)) }, + ], + ); + + let fold_index = builder.method( + "fold_index", + &[index_decl, ref_self_decl] + .into_iter() + .chain(generic_decls.iter().cloned()) + .collect::>(), + &[], + &[ + vir::expr! { acc([borrowed_pred](ref_self, ..[generic_exprs])) }, + vir::expr! { acc([generic_enc.ref_to_pred]([index_access](ref_self, ..[generic_exprs]), ..[generic_exprs])) }, + vir::expr! { ((0) <= (index)) && ((index) < (vpr_seq_len([snap_data.snap_to_prim]([borrowed_snap](ref_self, ..[generic_exprs]))))) }, + ], + &[ + vir::expr! { acc([self_pred](ref_self, ..[generic_exprs])) }, + vir::expr! { (vpr_seq_len(self_val)) == (old(vpr_seq_len([snap_data.snap_to_prim]([borrowed_snap](ref_self, ..[generic_exprs]))))) }, + vir::expr! { + forall i: [&vir::TypeData::Int] :: {[builder.vcx.mk_bin_op_expr(vir::BinOpKind::SeqIndex, self_val, vir::expr! { i })]} (((0) <= (i)) && ((i) < (vpr_seq_len(self_val)))) + ==> (([builder.vcx.mk_bin_op_expr(vir::BinOpKind::SeqIndex, self_val, vir::expr! { i })]) == ([builder.vcx.mk_ternary_expr( + vir::expr! { (i) == (index) }, + vir::expr! { old([generic_enc.ref_to_snap]([index_access](ref_self, ..[generic_exprs]), ..[generic_exprs])) }, + vir::expr! { old([builder.vcx.mk_bin_op_expr( + vir::BinOpKind::SeqIndex, + vir::expr! { [snap_data.snap_to_prim]([borrowed_snap](ref_self, ..[generic_exprs])) }, + vir::expr! { i }, + )]) }, + )])) + }, + ], + ); + + Ok(( + PredicateEncData::Array(PredicateEncDataArray { + snap_data: snap.specifics.expect_array(), + index_access, + unfold_index, + fold_index, + }), + None, + )) +} diff --git a/prusti-encoder/src/encoders/type/kinds/immref.rs b/prusti-encoder/src/encoders/type/kinds/immref.rs index 1c2cc803bb7..a600813bbbd 100644 --- a/prusti-encoder/src/encoders/type/kinds/immref.rs +++ b/prusti-encoder/src/encoders/type/kinds/immref.rs @@ -1,14 +1,51 @@ use crate::encoders::{ - domain::{DomainBuilder, DomainDataImmRef, DomainEnc, DomainEncSpecifics, DomainEncOutputRef}, - predicate::{PredicateBuilder, PredicateEncData, PredicateEncDataImmRef}, + domain::{DomainBuilder, DomainEnc, DomainEncSpecifics, DomainEncOutputRef}, + predicate::{PredicateBuilder, PredicateEncData}, rust_ty_snapshots::RustTySnapshotsEnc, snapshot::SnapshotEncOutput, - GenericEnc, PredicateEnc, + GenericEnc, PredicateEnc, PredicateEncOutputRef, }; use crate::TyConstructorEnc; use prusti_rustc_interface::middle::ty; use task_encoder::{EncodeFullError, TaskEncoder, TaskEncoderDependencies}; -use vir::ToKnownArity; +use vir::{BinaryArity, FunctionIdent, ToKnownArity, UnaryArity}; + +#[derive(Clone, Copy, Debug)] +pub struct DomainDataImmRef<'vir> { + /// Construct domain from a `Ref` value. + pub prim_to_snap: FunctionIdent<'vir, BinaryArity<'vir>>, + /// Function to access the referee. + pub deref_access: FunctionIdent<'vir, UnaryArity<'vir>>, + /// Function to access the snapshot value. + pub value_access: FunctionIdent<'vir, UnaryArity<'vir>>, +} + +impl<'vir> DomainEncSpecifics<'vir> { + #[track_caller] + pub fn expect_immref(self) -> DomainDataImmRef<'vir> { + match self { + Self::ImmRef(data) => data, + _ => panic!("expected immref domain data (got {self:?})"), + } + } +} + +#[derive(Clone, Copy, Debug)] +pub struct PredicateEncDataImmRef<'vir> { + pub deref_func: vir::FunctionIdent<'vir, BinaryArity<'vir>>, + pub perm: Option>, + pub snap_data: DomainDataImmRef<'vir>, +} + +impl<'vir> PredicateEncOutputRef<'vir> { + #[track_caller] + pub fn expect_immref(&self) -> PredicateEncDataImmRef<'vir> { + match self.specifics { + PredicateEncData::ImmRef(r) => r, + s => panic!("expected immref predicate data (got {s:?})"), + } + } +} pub(crate) fn domain<'vir>( task_key: ::TaskKey<'vir>, diff --git a/prusti-encoder/src/encoders/type/kinds/mod.rs b/prusti-encoder/src/encoders/type/kinds/mod.rs index c7a6ca96f5a..b940d0c5481 100644 --- a/prusti-encoder/src/encoders/type/kinds/mod.rs +++ b/prusti-encoder/src/encoders/type/kinds/mod.rs @@ -1,6 +1,7 @@ //! Encoding for MIR types, organised by type kind. pub mod adt; +pub mod array; pub mod closure; pub mod immref; pub mod mutref; @@ -10,4 +11,4 @@ pub mod param; pub mod primitive; pub mod str; pub mod tuple; -mod structlike; +pub(super) mod structlike; diff --git a/prusti-encoder/src/encoders/type/kinds/mutref.rs b/prusti-encoder/src/encoders/type/kinds/mutref.rs index 47c15600b66..21ca2e4f0b3 100644 --- a/prusti-encoder/src/encoders/type/kinds/mutref.rs +++ b/prusti-encoder/src/encoders/type/kinds/mutref.rs @@ -1,13 +1,50 @@ use crate::encoders::{ - domain::{DomainBuilder, DomainDataMutRef, DomainEnc, DomainEncSpecifics}, - predicate::{PredicateBuilder, PredicateEncData, PredicateEncDataMutRef}, + domain::{DomainBuilder, DomainEnc, DomainEncSpecifics}, + predicate::{PredicateBuilder, PredicateEncData}, rust_ty_snapshots::RustTySnapshotsEnc, snapshot::SnapshotEncOutput, - PredicateEnc, + PredicateEnc, PredicateEncOutputRef, }; use prusti_rustc_interface::middle::ty; use task_encoder::{EncodeFullError, TaskEncoder, TaskEncoderDependencies}; -use vir::ToKnownArity; +use vir::{BinaryArity, FunctionIdent, ToKnownArity, UnaryArity}; + +#[derive(Clone, Copy, Debug)] +pub struct DomainDataMutRef<'vir> { + /// Construct domain from a `Ref` value. + pub prim_to_snap: FunctionIdent<'vir, BinaryArity<'vir>>, + /// Function to access the referee. + pub deref_access: FunctionIdent<'vir, UnaryArity<'vir>>, + /// Function to access the snapshot value. + pub value_access: FunctionIdent<'vir, UnaryArity<'vir>>, +} + +impl<'vir> DomainEncSpecifics<'vir> { + #[track_caller] + pub fn expect_mutref(self) -> DomainDataMutRef<'vir> { + match self { + Self::MutRef(data) => data, + _ => panic!("expected mutref domain data (got {self:?})"), + } + } +} + +#[derive(Clone, Copy, Debug)] +pub struct PredicateEncDataMutRef<'vir> { + pub deref_func: vir::FunctionIdent<'vir, UnaryArity<'vir>>, + pub perm: Option>, + pub snap_data: DomainDataMutRef<'vir>, +} + +impl<'vir> PredicateEncOutputRef<'vir> { + #[track_caller] + pub fn expect_mutref(&self) -> PredicateEncDataMutRef<'vir> { + match self.specifics { + PredicateEncData::MutRef(r) => r, + s => panic!("expected mutref predicate data (got {s:?})"), + } + } +} pub(crate) fn domain<'vir>( task_key: ::TaskKey<'vir>, diff --git a/prusti-encoder/src/encoders/type/kinds/primitive.rs b/prusti-encoder/src/encoders/type/kinds/primitive.rs index ce842a245e6..ff2c7def9b8 100644 --- a/prusti-encoder/src/encoders/type/kinds/primitive.rs +++ b/prusti-encoder/src/encoders/type/kinds/primitive.rs @@ -1,12 +1,82 @@ use crate::encoders::{ - domain::{DomainBuilder, DomainDataPrim, DomainEnc, DomainEncSpecifics}, + domain::{DomainBuilder, DomainEnc, DomainEncSpecifics}, predicate::{PredicateBuilder, PredicateEncData}, snapshot::SnapshotEncOutput, - PredicateEnc, + PredicateEnc, PredicateEncOutputRef, +}; +use prusti_rustc_interface::{ + middle::ty::{self, TyKind}, + target::abi, }; -use prusti_rustc_interface::middle::ty; use task_encoder::{EncodeFullError, TaskEncoder, TaskEncoderDependencies}; -use vir::ToKnownArity; +use vir::{FunctionIdent, ToKnownArity, UnaryArity}; + +#[derive(Clone, Copy, Debug)] +pub struct DomainDataPrim<'vir> { + pub prim_type: vir::Type<'vir>, + /// Snapshot of self as argument. Returns Viper primitive value. + pub snap_to_prim: FunctionIdent<'vir, UnaryArity<'vir>>, + /// Viper primitive value as argument. Returns domain. + pub prim_to_snap: FunctionIdent<'vir, UnaryArity<'vir>>, +} + +impl<'vir> DomainEncSpecifics<'vir> { + #[track_caller] + pub fn expect_primitive(self) -> DomainDataPrim<'vir> { + match self { + Self::Primitive(data) => data, + _ => panic!("expected primitive domain data (got {self:?})"), + } + } +} + +// TODO: PredicateEncDataPrim + +impl<'vir> PredicateEncOutputRef<'vir> { + #[track_caller] + pub fn expect_prim(&self) -> DomainDataPrim<'vir> { + match self.specifics { + PredicateEncData::Primitive(prim) => prim, + s => panic!("expected primitive predicate data (got {s:?})"), + } + } +} + +impl<'vir> DomainDataPrim<'vir> { + pub fn expr_from_bits(&self, ty: ty::Ty<'vir>, value: u128) -> vir::Expr<'vir> { + match *self.prim_type { + vir::TypeData::Bool => { + vir::with_vcx(|vcx| vcx.mk_const_expr(vir::ConstData::Bool(value != 0))) + } + vir::TypeData::Int => { + let (bit_width, signed) = match ty.kind() { + TyKind::Int(ty::IntTy::Isize) => ((std::mem::size_of::() * 8) as u64, true), + TyKind::Int(ty) => (ty.bit_width().unwrap(), true), + TyKind::Uint(ty::UintTy::Usize) => { + ((std::mem::size_of::() * 8) as u64, true) + } + TyKind::Uint(ty) => (ty.bit_width().unwrap(), false), + kind => unreachable!("{kind:?}"), + }; + let size = abi::Size::from_bits(bit_width); + let negative_value = if signed { + let value = size.sign_extend(value); + Some(value).filter(|value| value.is_negative()) + } else { + None + }; + match negative_value { + Some(value) => vir::with_vcx(|vcx| { + let value = vcx.mk_const_expr(vir::ConstData::Int(value.unsigned_abs())); + vcx.mk_unary_op_expr(vir::UnOpKind::Neg, value) + }), + None => vir::with_vcx(|vcx| vcx.mk_const_expr(vir::ConstData::Int(value))), + } + } + ref k => unreachable!("{k:?}"), + } + } +} pub(crate) fn domain<'vir>( task_key: ::TaskKey<'vir>, diff --git a/prusti-encoder/src/encoders/type/kinds/structlike.rs b/prusti-encoder/src/encoders/type/kinds/structlike.rs index 93ab54b41db..6818c04fd9b 100644 --- a/prusti-encoder/src/encoders/type/kinds/structlike.rs +++ b/prusti-encoder/src/encoders/type/kinds/structlike.rs @@ -1,15 +1,51 @@ use crate::encoders::{ - domain::{DomainBuilder, DomainEnc, DomainEncOutputRef, FieldFunctions, FieldTy}, + domain::{DomainBuilder, DomainEnc, DomainEncOutputRef, DomainEncSpecifics, FieldFunctions, FieldTy}, lifted::ty_constructor::TyConstructorEnc, - predicate::PredicateBuilder, + predicate::{PredicateBuilder, PredicateEncData}, rust_ty_predicates::RustTyPredicatesEncOutputRef, snapshot::SnapshotEncOutput, - GenericEnc, PredicateEnc, + GenericEnc, PredicateEnc, PredicateEncOutputRef, }; use prusti_rustc_interface::middle::ty::{ParamTy, TyKind}; use task_encoder::{EncodeFullError, TaskEncoder, TaskEncoderDependencies}; use vir::{vir_format, FunctionIdent, PredicateIdent, ToKnownArity, UnknownArity}; +#[derive(Clone, Copy, Debug)] +pub struct DomainDataStruct<'vir> { + /// Construct domain from snapshots of fields or for primitive types + /// from the single Viper primitive value. + pub field_snaps_to_snap: FunctionIdent<'vir, UnknownArity<'vir>>, + /// Functions to access the fields. + pub field_access: &'vir [FieldFunctions<'vir>], +} + +impl<'vir> DomainEncSpecifics<'vir> { + #[track_caller] + pub fn expect_structlike(self) -> DomainDataStruct<'vir> { + match self { + Self::StructLike(data) => data, + _ => panic!("expected structlike domain data (got {self:?}"), + } + } +} + +impl<'vir> PredicateEncOutputRef<'vir> { + #[track_caller] + pub fn expect_structlike(&self) -> &PredicateEncDataStruct<'vir> { + match &self.specifics { + PredicateEncData::StructLike(data) => data, + s => panic!("expected structlike predicate data (got {s:?}"), + } + } +} + +#[derive(Clone, Copy, Debug)] +pub struct PredicateEncDataStruct<'vir> { + pub snap_data: DomainDataStruct<'vir>, + /// Ref to self as argument. Returns Ref to field. + pub ref_to_field_refs: &'vir [FunctionIdent<'vir, UnknownArity<'vir>>], +} + pub fn domain<'vir>( prefix: &str, fields: &[FieldTy<'vir>], diff --git a/prusti-encoder/src/encoders/type/kinds/tuple.rs b/prusti-encoder/src/encoders/type/kinds/tuple.rs index 1c752cec2d1..66164ca2cf6 100644 --- a/prusti-encoder/src/encoders/type/kinds/tuple.rs +++ b/prusti-encoder/src/encoders/type/kinds/tuple.rs @@ -2,7 +2,7 @@ use crate::encoders::{ domain::{ DomainBuilder, DomainDataStruct, DomainEnc, DomainEncOutputRef, DomainEncSpecifics, FieldTy, }, - lifted::ty::{EncodeGenericsAsParamTy, LiftedTyEnc}, + lifted::ty::{EncodeGenericsAsParamTy, LiftedTyEnc, LiftedTyEncTask}, predicate::{PredicateBuilder, PredicateEncData, PredicateEncDataStruct}, rust_ty_predicates::RustTyPredicatesEnc, snapshot::SnapshotEncOutput, @@ -26,7 +26,7 @@ pub(crate) fn domain<'vir>( let generics = params .iter() .map(|ty| { - deps.require_local::>(ty) + deps.require_local::>(LiftedTyEncTask::Ty(ty)) .unwrap() .expect_generic() }) diff --git a/prusti-encoder/src/encoders/type/lifted/aggregate_cast.rs b/prusti-encoder/src/encoders/type/lifted/aggregate_cast.rs index c60ba6f3b51..e13a95f5ca9 100644 --- a/prusti-encoder/src/encoders/type/lifted/aggregate_cast.rs +++ b/prusti-encoder/src/encoders/type/lifted/aggregate_cast.rs @@ -25,6 +25,7 @@ pub struct AggregateSnapArgsCastEncTask<'tcx> { #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum AggregateType<'tcx> { + Array, Tuple, Closure { def_id: DefId, @@ -39,6 +40,7 @@ pub enum AggregateType<'tcx> { impl<'tcx> From<&mir::AggregateKind<'tcx>> for AggregateType<'tcx> { fn from(aggregate_kind: &mir::AggregateKind<'tcx>) -> Self { match aggregate_kind { + mir::AggregateKind::Array(elem_ty) => Self::Array, mir::AggregateKind::Tuple => Self::Tuple, mir::AggregateKind::Closure(def_id, args) => Self::Closure { def_id: *def_id, @@ -93,6 +95,18 @@ impl TaskEncoder for AggregateSnapArgsCastEnc { deps.emit_output_ref(task_key.clone(), ())?; vir::with_vcx(|vcx| { let cast_functions: Vec>> = match task_key.aggregate_type { + AggregateType::Array => task_key + .tys + .iter() + .map(|ty| { + let cast_functions = deps + .require_local::>(*ty) + .unwrap(); + cast_functions + .to_generic_cast() + .map(|c| c.map_applicator(|f| f.as_unknown_arity())) + }) + .collect::>(), AggregateType::Tuple => task_key .tys .iter() diff --git a/prusti-encoder/src/encoders/type/lifted/casters.rs b/prusti-encoder/src/encoders/type/lifted/casters.rs index ad9bab50d29..279a320fd33 100644 --- a/prusti-encoder/src/encoders/type/lifted/casters.rs +++ b/prusti-encoder/src/encoders/type/lifted/casters.rs @@ -9,7 +9,7 @@ use crate::encoders::{ }; use super::{ - generic::{LiftedGeneric, LiftedGenericEnc}, + generic::{LiftedGeneric, LiftedGenericEnc, LiftedGenericEncTask}, ty::LiftedTy, }; @@ -249,7 +249,7 @@ impl TaskEncoder for CastersEnc { let ty_params = ty .generics() .into_iter() - .map(|g| deps.require_ref::(*g)) + .map(|g| deps.require_ref::(LiftedGenericEncTask::Param(*g))) .collect::, _>>()?; let make_generic_arg_tys = std::iter::once(self_ty) diff --git a/prusti-encoder/src/encoders/type/lifted/func_app_ty_params.rs b/prusti-encoder/src/encoders/type/lifted/func_app_ty_params.rs index b24f4b6cc51..08715720a3e 100644 --- a/prusti-encoder/src/encoders/type/lifted/func_app_ty_params.rs +++ b/prusti-encoder/src/encoders/type/lifted/func_app_ty_params.rs @@ -5,7 +5,7 @@ use task_encoder::{EncodeFullResult, TaskEncoder}; use super::{ generic::LiftedGeneric, - ty::{EncodeGenericsAsLifted, LiftedTy, LiftedTyEnc}, + ty::{EncodeGenericsAsLifted, LiftedTy, LiftedTyEnc, LiftedTyEncTask}, }; /// Encodes the type parameters to a function application. If we are @@ -44,12 +44,17 @@ impl TaskEncoder for LiftedFuncAppTyParamsEnc { }; let ty_args = ty_args .iter() - .map(|ty| { - deps.require_local::>(*ty) - .unwrap() - }) + .map(|ty| deps.require_local::>(LiftedTyEncTask::Ty(*ty))) + .collect::, _>>()?; + let const_args = substs + .iter() + .filter_map(|arg| arg.as_const()) + .map(|c| deps.require_local::>(LiftedTyEncTask::Const(c))) + .collect::, _>>()?; + let all_args = ty_args.into_iter() + .chain(const_args) .collect::>(); - Ok((vcx.alloc_slice(&ty_args), ())) + Ok((vcx.alloc_slice(&all_args), ())) }) } } diff --git a/prusti-encoder/src/encoders/type/lifted/func_def_ty_params.rs b/prusti-encoder/src/encoders/type/lifted/func_def_ty_params.rs index ebccf5b0b6c..e8ff09db0e0 100644 --- a/prusti-encoder/src/encoders/type/lifted/func_def_ty_params.rs +++ b/prusti-encoder/src/encoders/type/lifted/func_def_ty_params.rs @@ -2,7 +2,7 @@ use prusti_rustc_interface::middle::ty::{self, ParamTy, Ty, TyKind}; use std::collections::HashSet; use task_encoder::{EncodeFullResult, TaskEncoder}; -use super::generic::{LiftedGeneric, LiftedGenericEnc}; +use super::generic::{LiftedGeneric, LiftedGenericEnc, LiftedGenericEncTask}; /// Encodes the type parameters of a (possibly monomorphised) function /// definition. It takes as input a type substitution and returns the list of @@ -36,9 +36,17 @@ impl TaskEncoder for LiftedTyParamsEnc { .filter_map(|arg| arg.as_type()) .flat_map(extract_ty_params); let ty_args = unique(ty_args) - .map(|ty| deps.require_ref::(ty).unwrap()) + .map(|ty| deps.require_ref::(LiftedGenericEncTask::Param(ty))) + .collect::, _>>()?; + let const_args = task_key + .iter() + .filter_map(|arg| arg.as_const()) + .map(|c| deps.require_ref::(LiftedGenericEncTask::Const(c))) + .collect::, _>>()?; + let all_args = ty_args.into_iter() + .chain(const_args) .collect::>(); - Ok((vcx.alloc_slice(&ty_args), ())) + Ok((vcx.alloc_slice(&all_args), ())) }) } } diff --git a/prusti-encoder/src/encoders/type/lifted/generic.rs b/prusti-encoder/src/encoders/type/lifted/generic.rs index b5e4965585e..54703682451 100644 --- a/prusti-encoder/src/encoders/type/lifted/generic.rs +++ b/prusti-encoder/src/encoders/type/lifted/generic.rs @@ -28,12 +28,18 @@ impl<'vir> LiftedGeneric<'vir> { impl<'vir> OutputRefAny for LiftedGeneric<'vir> {} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum LiftedGenericEncTask<'vir> { + Param(ty::ParamTy), + Const(ty::Const<'vir>), +} + pub struct LiftedGenericEnc; impl TaskEncoder for LiftedGenericEnc { task_encoder::encoder_cache!(LiftedGenericEnc); - type TaskDescription<'tcx> = ty::ParamTy; + type TaskDescription<'tcx> = LiftedGenericEncTask<'tcx>; type TaskKey<'tcx> = Self::TaskDescription<'tcx>; @@ -52,11 +58,21 @@ impl TaskEncoder for LiftedGenericEnc { deps: &mut task_encoder::TaskEncoderDependencies<'vir, Self>, ) -> EncodeFullResult<'vir, Self> { with_vcx(|vcx| { - let output_ref = vcx.mk_local_decl( - vcx.alloc_str(task_key.name.as_str()), - deps.require_ref::(())?.type_snapshot, - ); - deps.emit_output_ref(*task_key, LiftedGeneric(output_ref))?; + let output_ref = LiftedGeneric(match task_key { + LiftedGenericEncTask::Param(param) => vcx.mk_local_decl( + vcx.alloc_str(param.name.as_str()), + deps.require_ref::(())?.type_snapshot, + ), + LiftedGenericEncTask::Const(c) => match c.kind() { + ty::ConstKind::Param(param) => vcx.mk_local_decl( + vcx.alloc_str(param.name.as_str()), + deps.require_ref::(())?.type_snapshot, + ), + _ => todo!("lifted generic const {c:?}") + } + // LiftedGenericEncTask::Const(c) => todo!("lifted generic const {c:?}"), + }); + deps.emit_output_ref(*task_key, output_ref)?; Ok(((), ())) }) } diff --git a/prusti-encoder/src/encoders/type/lifted/rust_ty_cast.rs b/prusti-encoder/src/encoders/type/lifted/rust_ty_cast.rs index 00a83a4c27c..1e934c77d54 100644 --- a/prusti-encoder/src/encoders/type/lifted/rust_ty_cast.rs +++ b/prusti-encoder/src/encoders/type/lifted/rust_ty_cast.rs @@ -13,7 +13,7 @@ use super::{ Casters, CastersEnc, ImpureCastStmts, MakeGenericCastFunction, }, generic::LiftedGeneric, - ty::{EncodeGenericsAsLifted, LiftedTy, LiftedTyEnc}, + ty::{EncodeGenericsAsLifted, LiftedTy, LiftedTyEnc, LiftedTyEncTask}, }; /// Generates Viper functions to cast between generic and non-generic Viper @@ -89,7 +89,7 @@ where let ty_args = args .iter() .map(|a| { - deps.require_local::>(*a) + deps.require_local::>(LiftedTyEncTask::Ty(*a)) .unwrap() }) .collect::>(); diff --git a/prusti-encoder/src/encoders/type/lifted/ty.rs b/prusti-encoder/src/encoders/type/lifted/ty.rs index 7b83434afb1..3cc201e912a 100644 --- a/prusti-encoder/src/encoders/type/lifted/ty.rs +++ b/prusti-encoder/src/encoders/type/lifted/ty.rs @@ -6,12 +6,13 @@ use vir::{with_vcx, FunctionIdent, UnknownArity}; use crate::encoders::{ lifted::{ - generic::{LiftedGeneric, LiftedGenericEnc}, - ty_constructor::TyConstructorEnc, + casters::CastTypePure, generic::{LiftedGeneric, LiftedGenericEnc}, rust_ty_cast::RustTyCastersEnc, ty_constructor::TyConstructorEnc }, - most_generic_ty::extract_type_params, + most_generic_ty::extract_type_params, rust_ty_snapshots::RustTySnapshotsEnc, }; +use super::generic::LiftedGenericEncTask; + /// Representation of a Rust type as a Viper expression. Generics are /// represented with values of type `T`. In the usual case `T` should be /// [`LiftedGeneric`], but in some cases alternative types are useful (see @@ -29,6 +30,8 @@ pub enum LiftedTy<'vir, T> { /// Arguments to the type constructor e.g. `T` in `Option` args: &'vir [LiftedTy<'vir, T>], }, + + Expr(vir::Expr<'vir>), } impl<'vir, 'tcx, T: Copy> LiftedTy<'vir, T> { @@ -50,6 +53,7 @@ impl<'vir, 'tcx, T: Copy> LiftedTy<'vir, T> { } } LiftedTy::Generic(g) => LiftedTy::Generic(f(*g)), + LiftedTy::Expr(e) => LiftedTy::Expr(e), } } @@ -66,6 +70,7 @@ impl<'vir, 'tcx, Curr, Next> LiftedTy<'vir, vir::ExprGen<'vir, Curr, Next>> { match self { LiftedTy::Generic(g) => vec![*g], LiftedTy::Instantiated { args, .. } => args.iter().map(|a| a.expr(vcx)).collect(), + LiftedTy::Expr(..) => Vec::new(), } } @@ -76,6 +81,7 @@ impl<'vir, 'tcx, Curr, Next> LiftedTy<'vir, vir::ExprGen<'vir, Curr, Next>> { ty_constructor, args, } => ty_constructor.apply(vcx, &args.iter().map(|a| a.expr(vcx)).collect::>()), + LiftedTy::Expr(e) => unsafe { std::mem::transmute(e.lift::()) }, // TODO: .... } } } @@ -99,6 +105,12 @@ impl<'vir, 'tcx> LiftedTy<'vir, LiftedGeneric<'vir>> { pub struct EncodeGenericsAsLifted; pub struct EncodeGenericsAsParamTy; +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum LiftedTyEncTask<'vir> { + Ty(ty::Ty<'vir>), + Const(ty::Const<'vir>), +} + /// Encodes the Viper representation of a Rust type ([`LiftedTy`]). The type /// parameter `T` determines how Rust generic types are encoded; different /// encoder implementations are used for different types of generic types. The @@ -112,7 +124,7 @@ pub struct LiftedTyEnc(PhantomData); impl TaskEncoder for LiftedTyEnc { task_encoder::encoder_cache!(LiftedTyEnc); - type TaskDescription<'tcx> = ty::Ty<'tcx>; + type TaskDescription<'tcx> = LiftedTyEncTask<'tcx>; type TaskKey<'tcx> = Self::TaskDescription<'tcx>; @@ -124,6 +136,73 @@ impl TaskEncoder for LiftedTyEnc { *task } + /* + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut task_encoder::TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + deps.emit_output_ref(*task_key, ())?; + with_vcx(|vcx| match task_key { + LiftedTyEncTask::Const(c) => match c.kind() { + ty::ConstKind::Value(ty, value) => { + let kind = deps + .require_local::(ty)? + .generic_snapshot + .specifics; + let prim = kind.expect_primitive(); + let val = value.try_to_scalar_int().unwrap() + .to_bits_unchecked(); // TODO: less unwrapping + let val = prim.expr_from_bits(ty, val); + Ok((LiftedTy::Generic(vir::with_vcx(|vcx| prim.prim_to_snap.apply(vcx, [val]))), ())) + //Ok((LiftedTy::Expr(vir::with_vcx(|vcx| prim.prim_to_snap.apply(vcx, [val]))), ())) + } + _ => todo!("encode const {:?}", c.kind()), + } + _ => { + let result = deps.require_local::>(*task_key)?; + let result = result.map(vcx, &mut |g| { + deps.require_ref::(LiftedGenericEncTask::Param(g)).unwrap() + }); + Ok((result, ())) + } + /* + LiftedTyEncTask::Ty(ty) => { + if let TyKind::Param(p) = ty.kind() { + return Ok((LiftedTy::Generic(*p), ())); + } + let (ty_constructor, args) = extract_type_params(vcx.tcx(), *ty); + let ty_constructor = deps + .require_ref::(ty_constructor)? + .ty_constructor; + let args = args + .into_iter() + .map(|ty| deps.require_local::(LiftedTyEncTask::Ty(ty)).unwrap()) + .collect::>(); + Ok((LiftedTy::Instantiated { + ty_constructor, + args: vcx.alloc_slice(&args), + }, ())) + } + LiftedTyEncTask::Const(c) => todo!(),/* + LiftedTyEncTask::Const(c) => match c.kind() { + ty::ConstKind::Value(ty, value) => { + let kind = deps + .require_local::(ty)? + .generic_snapshot + .specifics; + let prim = kind.expect_primitive(); + let val = value.try_to_scalar_int().unwrap() + .to_bits_unchecked(); // TODO: less unwrapping + let val = prim.expr_from_bits(ty, val); + Ok((LiftedTy::Generic(vir::with_vcx(|vcx| prim.prim_to_snap.apply(vcx, [val]))), ())) + //Ok((LiftedTy::Expr(vir::with_vcx(|vcx| prim.prim_to_snap.apply(vcx, [val]))), ())) + } + _ => todo!("encode const {:?}", c.kind()), + }*/ + */ + }) + } + */ fn do_encode_full<'vir>( task_key: &Self::TaskKey<'vir>, deps: &mut task_encoder::TaskEncoderDependencies<'vir, Self>, @@ -132,7 +211,7 @@ impl TaskEncoder for LiftedTyEnc { with_vcx(|vcx| { let result = deps.require_local::>(*task_key)?; let result = result.map(vcx, &mut |g| { - deps.require_ref::(g).unwrap() + deps.require_ref::(LiftedGenericEncTask::Param(g)).unwrap() }); Ok((result, ())) }) @@ -144,7 +223,7 @@ impl TaskEncoder for LiftedTyEnc { impl TaskEncoder for LiftedTyEnc { task_encoder::encoder_cache!(LiftedTyEnc); - type TaskDescription<'tcx> = ty::Ty<'tcx>; + type TaskDescription<'tcx> = LiftedTyEncTask<'tcx>; type TaskKey<'tcx> = Self::TaskDescription<'tcx>; @@ -161,23 +240,41 @@ impl TaskEncoder for LiftedTyEnc { deps: &mut task_encoder::TaskEncoderDependencies<'vir, Self>, ) -> EncodeFullResult<'vir, Self> { deps.emit_output_ref(*task_key, ())?; - with_vcx(|vcx| { - if let TyKind::Param(p) = task_key.kind() { - return Ok((LiftedTy::Generic(*p), ())); + with_vcx(|vcx| match task_key { + LiftedTyEncTask::Ty(ty) => { + if let TyKind::Param(p) = ty.kind() { + return Ok((LiftedTy::Generic(*p), ())); + } + let (ty_constructor, args) = extract_type_params(vcx.tcx(), *ty); + let ty_constructor = deps + .require_ref::(ty_constructor)? + .ty_constructor; + let args = args + .into_iter() + .map(|ty| deps.require_local::(LiftedTyEncTask::Ty(ty)).unwrap()) + .collect::>(); + Ok((LiftedTy::Instantiated { + ty_constructor, + args: vcx.alloc_slice(&args), + }, ())) + } + LiftedTyEncTask::Const(c) => match c.kind() { + ty::ConstKind::Value(ty, value) => { + let kind = deps + .require_local::(ty)? + .generic_snapshot + .specifics; + let prim = kind.expect_primitive(); + let val = value.try_to_scalar_int().unwrap() + .to_bits_unchecked(); // TODO: less unwrapping + let val = prim.expr_from_bits(ty, val); + let snap = prim.prim_to_snap.apply(vcx, [val]); + let cast = deps.require_local::>(ty)?; + let snap = cast.cast_to_generic_if_necessary(vcx, snap); + Ok((LiftedTy::Expr(snap), ())) + } + _ => todo!("encode const {:?}", c.kind()), } - let (ty_constructor, args) = extract_type_params(vcx.tcx(), *task_key); - let ty_constructor = deps - .require_ref::(ty_constructor)? - .ty_constructor; - let args = args - .into_iter() - .map(|ty| deps.require_local::(ty).unwrap()) - .collect::>(); - let result = LiftedTy::Instantiated { - ty_constructor, - args: vcx.alloc_slice(&args), - }; - Ok((result, ())) }) } } diff --git a/prusti-encoder/src/encoders/type/most_generic_ty.rs b/prusti-encoder/src/encoders/type/most_generic_ty.rs index 5999d8ce9f3..715c7d7692d 100644 --- a/prusti-encoder/src/encoders/type/most_generic_ty.rs +++ b/prusti-encoder/src/encoders/type/most_generic_ty.rs @@ -39,6 +39,7 @@ pub fn get_vir_base_name_kind<'tcx>(kind: &ty::TyKind<'tcx>, vcx: &vir::VirCtxt< TyKind::RawPtr(_, ty::Mutability::Not) => String::from("RawPtr_immutable"), TyKind::RawPtr(_, ty::Mutability::Mut) => String::from("RawPtr_mutable"), TyKind::Param(_) => String::from("Param"), + TyKind::Array(_, _) => format!("Array"), // , get_vir_base_name_kind(elem_ty.kind(), vcx)), TyKind::Closure(def_id, _) => { let def_key = vcx.tcx().def_key(def_id); match def_key.disambiguated_data.data { diff --git a/prusti-encoder/src/encoders/type/predicate.rs b/prusti-encoder/src/encoders/type/predicate.rs index e611e1f562d..c0c54dffbb5 100644 --- a/prusti-encoder/src/encoders/type/predicate.rs +++ b/prusti-encoder/src/encoders/type/predicate.rs @@ -1,20 +1,22 @@ -use prusti_rustc_interface::{ - middle::ty::{self, TyKind}, - target::abi, -}; +use prusti_rustc_interface::middle::ty::{self, TyKind}; use task_encoder::{EncodeFullResult, TaskEncoder, TaskEncoderDependencies}; use vir::{ - BinaryArity, CallableIdent, FunctionIdent, MethodIdent, NullaryArity, PredicateIdent, TypeData, - UnaryArity, UnknownArity, VirCtxt, + CallableIdent, FunctionIdent, MethodIdent, NullaryArity, PredicateIdent, TypeData, + UnknownArity, VirCtxt, }; use crate::encoders::GenericEnc; use super::{ - domain::{DomainDataImmRef, DomainDataMutRef, DomainDataPrim, DomainDataStruct}, - lifted::{generic::LiftedGeneric, ty::LiftedTy}, - most_generic_ty::{get_vir_base_name_kind, MostGenericTy}, - snapshot::SnapshotEnc, + kinds::primitive::DomainDataPrim, lifted::{generic::LiftedGeneric, ty::LiftedTy}, most_generic_ty::{get_vir_base_name_kind, MostGenericTy}, snapshot::SnapshotEnc +}; + +pub use super::kinds::{ + adt::PredicateEncDataEnum, + array::PredicateEncDataArray, + immref::PredicateEncDataImmRef, + mutref::PredicateEncDataMutRef, + structlike::PredicateEncDataStruct, }; /// Takes a `MostGenericTy` and returns various Viper predicates and functions for @@ -26,46 +28,10 @@ pub enum PredicateEncError { // UnsupportedType, } -#[derive(Clone, Copy, Debug)] -pub struct PredicateEncDataStruct<'vir> { - pub snap_data: DomainDataStruct<'vir>, - /// Ref to self as argument. Returns Ref to field. - pub ref_to_field_refs: &'vir [FunctionIdent<'vir, UnknownArity<'vir>>], -} - -#[derive(Clone, Copy, Debug)] -pub struct PredicateEncDataEnum<'vir> { - pub discr: FunctionIdent<'vir, UnaryArity<'vir>>, - pub discr_prim: DomainDataPrim<'vir>, - //pub discr_bounds: DiscrBounds<'vir>, - // pub snap_to_discr_snap: FunctionIdent<'vir, UnaryArity<'vir>>, - pub variants: &'vir [PredicateEncDataVariant<'vir>], -} -#[derive(Clone, Copy, Debug)] -pub struct PredicateEncDataVariant<'vir> { - pub predicate: PredicateIdent<'vir, UnknownArity<'vir>>, - pub vid: abi::VariantIdx, - pub discr: vir::Expr<'vir>, - pub fields: PredicateEncDataStruct<'vir>, -} - -#[derive(Clone, Copy, Debug)] -pub struct PredicateEncDataImmRef<'vir> { - pub deref_func: vir::FunctionIdent<'vir, BinaryArity<'vir>>, - pub perm: Option>, - pub snap_data: DomainDataImmRef<'vir>, -} - -#[derive(Clone, Copy, Debug)] -pub struct PredicateEncDataMutRef<'vir> { - pub deref_func: vir::FunctionIdent<'vir, UnaryArity<'vir>>, - pub perm: Option>, - pub snap_data: DomainDataMutRef<'vir>, -} - #[derive(Clone, Copy, Debug)] pub enum PredicateEncData<'vir> { Never, + Array(PredicateEncDataArray<'vir>), Primitive(DomainDataPrim<'vir>), // structs, tuples Trusted, @@ -118,88 +84,6 @@ impl<'vir> PredicateEncOutputRef<'vir> { args.extend(instantiated_ty.arg_exprs(vcx)); vcx.alloc_slice(&args) } - - #[track_caller] - pub fn expect_prim(&self) -> DomainDataPrim<'vir> { - match self.specifics { - PredicateEncData::Primitive(prim) => prim, - _ => panic!("expected primitive type"), - } - } - #[track_caller] - pub fn expect_immref(&self) -> PredicateEncDataImmRef<'vir> { - match self.specifics { - PredicateEncData::ImmRef(r) => r, - s => panic!("expected immref type ({s:?})"), - } - } - #[track_caller] - pub fn expect_mutref(&self) -> PredicateEncDataMutRef<'vir> { - match self.specifics { - PredicateEncData::MutRef(r) => r, - s => panic!("expected mutref type ({s:?})"), - } - } - pub fn get_structlike(&self) -> Option<&PredicateEncDataStruct<'vir>> { - match &self.specifics { - PredicateEncData::StructLike(data) => Some(data), - _ => None, - } - } - #[track_caller] - pub fn expect_structlike(&self) -> &PredicateEncDataStruct<'vir> { - self.get_structlike().expect("expected structlike type") - } - pub fn get_enumlike(&self) -> Option<&Option>> { - match &self.specifics { - PredicateEncData::EnumLike(e) => Some(e), - _ => None, - } - } - #[track_caller] - pub fn expect_enumlike(&self) -> Option<&PredicateEncDataEnum<'vir>> { - self.get_enumlike() - .expect("expected enumlike type") - .as_ref() - } - pub fn get_variant_any(&self, vid: abi::VariantIdx) -> &PredicateEncDataStruct<'vir> { - match &self.specifics { - PredicateEncData::StructLike(s) => { - assert_eq!(vid, abi::FIRST_VARIANT); - s - } - PredicateEncData::EnumLike(e) => &e.as_ref().unwrap().variants[vid.as_usize()].fields, - _ => panic!("expected structlike or enumlike type"), - } - } - - #[track_caller] - pub fn expect_variant(&self, vid: abi::VariantIdx) -> &PredicateEncDataVariant<'vir> { - match &self.specifics { - PredicateEncData::EnumLike(e) => &e.as_ref().unwrap().variants[vid.as_usize()], - _ => panic!("expected enum type"), - } - } - #[track_caller] - pub fn expect_pred_variant_opt( - &self, - vid: Option, - ) -> PredicateIdent<'vir, UnknownArity<'vir>> { - vid.map(|vid| self.expect_variant(vid).predicate) - .unwrap_or(self.ref_to_pred) - } - #[track_caller] - pub fn expect_variant_opt( - &self, - vid: Option, - ) -> &PredicateEncDataStruct<'vir> { - match vid { - None => self.expect_structlike(), - Some(vid) => { - &self.expect_enumlike().expect("empty enum").variants[vid.as_usize()].fields - } - } - } } pub(crate) struct PredicateBuilder<'vir> { @@ -384,7 +268,7 @@ impl<'vir> PredicateBuilder<'vir> { unreachable_to_snap: self.unreachable_to_snap.unwrap().1, function_snap: self.function_snap.unwrap(), ref_to_field_refs: self.functions, - method_assign: self.methods[0], + methods: self.methods, } } } @@ -397,7 +281,7 @@ pub struct PredicateEncOutput<'vir> { pub unreachable_to_snap: vir::Function<'vir>, pub function_snap: vir::Function<'vir>, pub ref_to_field_refs: Vec>, - pub method_assign: vir::Method<'vir>, + pub methods: Vec>, } impl TaskEncoder for PredicateEnc { @@ -463,7 +347,7 @@ impl TaskEncoder for PredicateEnc { unreachable_to_snap: dep.unreachable_to_snap, function_snap: dep.ref_to_snap, ref_to_field_refs: vec![], - method_assign, + methods: vec![method_assign], }, (), )) @@ -567,6 +451,14 @@ impl TaskEncoder for PredicateEnc { | TyKind::Float(_) => { super::kinds::primitive::predicate(*task_key, snap.clone(), deps, &mut builder)? } + TyKind::Array(..) => super::kinds::array::predicate( + *task_key, + snap.clone(), + deps, + &generic_decls, + &generic_exprs, + &mut builder, + )?, TyKind::Adt(..) => super::kinds::adt::predicate( *task_key, snap.clone(), diff --git a/prusti-encoder/src/encoders/type/rust_ty_predicates.rs b/prusti-encoder/src/encoders/type/rust_ty_predicates.rs index 88510e6fc29..61b5d4d9264 100644 --- a/prusti-encoder/src/encoders/type/rust_ty_predicates.rs +++ b/prusti-encoder/src/encoders/type/rust_ty_predicates.rs @@ -6,7 +6,7 @@ use crate::encoders::{PredicateEnc, PredicateEncOutputRef}; use super::{ lifted::{ generic::LiftedGeneric, - ty::{EncodeGenericsAsLifted, LiftedTy, LiftedTyEnc}, + ty::{EncodeGenericsAsLifted, LiftedTy, LiftedTyEnc, LiftedTyEncTask}, }, most_generic_ty::extract_type_params, }; @@ -149,7 +149,7 @@ impl TaskEncoder for RustTyPredicatesEnc { None }; */ - let ty = deps.require_local::>(*task_key)?; + let ty = deps.require_local::>(LiftedTyEncTask::Ty(*task_key))?; deps.emit_output_ref( *task_key, RustTyPredicatesEncOutputRef { diff --git a/prusti-encoder/src/encoders/type/snapshot.rs b/prusti-encoder/src/encoders/type/snapshot.rs index 9fa79fea0cb..fb315a868c3 100644 --- a/prusti-encoder/src/encoders/type/snapshot.rs +++ b/prusti-encoder/src/encoders/type/snapshot.rs @@ -2,7 +2,7 @@ use task_encoder::{EncodeFullResult, TaskEncoder, TaskEncoderDependencies}; use super::{ domain::{DomainEnc, DomainEncSpecifics}, - lifted::generic::{LiftedGeneric, LiftedGenericEnc}, + lifted::generic::{LiftedGeneric, LiftedGenericEnc, LiftedGenericEncTask}, most_generic_ty::MostGenericTy, }; @@ -49,7 +49,7 @@ impl TaskEncoder for SnapshotEnc { let generics = vcx.alloc_slice( &ty.generics() .into_iter() - .map(|g| deps.require_ref::(*g).unwrap()) + .map(|g| deps.require_ref::(LiftedGenericEncTask::Param(*g)).unwrap()) .collect::>(), ); Ok(( diff --git a/prusti-encoder/src/lib.rs b/prusti-encoder/src/lib.rs index 5b6d27feedf..68fca9bef93 100644 --- a/prusti-encoder/src/lib.rs +++ b/prusti-encoder/src/lib.rs @@ -107,6 +107,15 @@ pub fn test_entrypoint<'tcx>( viper_code.push_str(&format!("{:?}\n", output.param_snapshot)); program_domains.push(output.type_snapshot); program_domains.push(output.param_snapshot); + /* + // TODO: should these be emitted by PredicateEnc? + viper_code.push_str(&format!("{:?}\n", output.ref_to_pred)); + viper_code.push_str(&format!("{:?}\n", output.ref_to_snap)); + viper_code.push_str(&format!("{:?}\n", output.unreachable_to_snap)); + program_predicates.push(output.ref_to_pred); + program_functions.push(output.ref_to_snap); + program_functions.push(output.unreachable_to_snap); + */ } header(&mut viper_code, "pure generic casts"); @@ -155,8 +164,10 @@ pub fn test_entrypoint<'tcx>( viper_code.push_str(&format!("{:?}\n", pred)); program_predicates.push(pred); } - viper_code.push_str(&format!("{:?}\n", output.method_assign)); - program_methods.push(output.method_assign); + for method in output.methods { + viper_code.push_str(&format!("{:?}\n", method)); + program_methods.push(method); + } } if std::env::var("LOCAL_TESTING").is_ok() { diff --git a/prusti-tests/tests/v2/pass/generics/const.rs b/prusti-tests/tests/v2/pass/generics/const.rs new file mode 100644 index 00000000000..48c30e5b9d5 --- /dev/null +++ b/prusti-tests/tests/v2/pass/generics/const.rs @@ -0,0 +1,19 @@ +use prusti_contracts::*; + +// struct Foo { +// arr: [i32; N], +// } + +#[requires(N > 0)] +fn foo() { + assert!(N > 0); +} +/* +#[requires(N > 10)] +fn bar() { + foo::(); +} +*/ +fn main() { + foo::<0, i32>(); +} diff --git a/prusti-tests/tests/v2/pass/generics/duplicate.rs b/prusti-tests/tests/v2/pass/generics/duplicate.rs new file mode 100644 index 00000000000..1fa36a6a79e --- /dev/null +++ b/prusti-tests/tests/v2/pass/generics/duplicate.rs @@ -0,0 +1,12 @@ +struct Foo { + f: T, +} + +fn foo(a: Foo, b: Foo>) {} + +fn main() { + foo( + Foo { f: 0 }, + Foo { f: Foo { f: 1 } }, + ); +} diff --git a/prusti-tests/tests/v2/pass/types/array.rs b/prusti-tests/tests/v2/pass/types/array.rs new file mode 100644 index 00000000000..1bfb46e98d2 --- /dev/null +++ b/prusti-tests/tests/v2/pass/types/array.rs @@ -0,0 +1,9 @@ +use prusti_contracts::*; + +#[requires(x[2] > 10)] +fn test1(x: [i32; 3]) { + assert!(x[2] > 0); +} + +#[trusted] +fn main() {} diff --git a/prusti-viper/src/lib.rs b/prusti-viper/src/lib.rs index 72f6e86e8bd..ef5fedad220 100644 --- a/prusti-viper/src/lib.rs +++ b/prusti-viper/src/lib.rs @@ -192,6 +192,7 @@ impl<'vir, 'v> ToViper<'vir, 'v> for vir::BinOp<'vir> { vir::BinOpKind::DivRational => ctx.ast.perm_div(lhs, rhs), // TODO: position vir::BinOpKind::Mod => ctx.ast.mod_with_pos(lhs, rhs, pos), vir::BinOpKind::Implies => ctx.ast.implies_with_pos(lhs, rhs, pos), + vir::BinOpKind::SeqIndex => ctx.ast.seq_index(lhs, rhs), // TODO: position } } } @@ -331,6 +332,7 @@ impl<'vir, 'v> ToViper<'vir, 'v> for vir::Expr<'vir> { vir::ExprKindData::Result(ty) => ctx .ast .result_with_pos(ty.to_viper_no_pos(ctx), ctx.span_to_pos(self.span)), + vir::ExprKindData::SeqLiteral(v) => v.to_viper_with_span(ctx, self.span), vir::ExprKindData::Ternary(v) => v.to_viper_with_span(ctx, self.span), vir::ExprKindData::Unfolding(v) => v.to_viper_with_span(ctx, self.span), vir::ExprKindData::UnOp(v) => v.to_viper_with_span(ctx, self.span), @@ -678,6 +680,19 @@ impl<'vir, 'v> ToViper<'vir, 'v> for vir::PureAssign<'vir> { } } +impl<'vir, 'v> ToViper<'vir, 'v> for vir::SeqLiteral<'vir> { + type Output = viper::Expr<'v>; + fn to_viper(&self, ctx: &ToViperContext<'vir, 'v>, _pos: Position) -> Self::Output { + ctx.ast.explicit_seq( + &self + .values + .iter() + .map(|v| v.to_viper_no_pos(ctx)) + .collect::>(), + ) + } +} + impl<'vir, 'v> ToViper<'vir, 'v> for vir::Stmt<'vir> { type Output = viper::Stmt<'v>; fn to_viper(&self, ctx: &ToViperContext<'vir, 'v>, _pos: Position) -> Self::Output { @@ -840,6 +855,7 @@ impl<'vir, 'v> ToViper<'vir, 'v> for vir::Type<'vir> { } vir::TypeData::Ref => ctx.ast.ref_type(), vir::TypeData::Perm => ctx.ast.perm_type(), + vir::TypeData::Seq(elem_ty) => ctx.ast.seq_type(elem_ty.to_viper_no_pos(ctx)), //vir::TypeData::Predicate, // The type of a predicate application //vir::TypeData::Unsupported(UnsupportedType<'vir>) other => unimplemented!("{:?}", other), @@ -855,6 +871,7 @@ impl<'vir, 'v> ToViper<'vir, 'v> for vir::UnOp<'vir> { match self.kind { vir::UnOpKind::Neg => ctx.ast.minus_with_pos(expr, pos), vir::UnOpKind::Not => ctx.ast.not_with_pos(expr, pos), + vir::UnOpKind::SeqLen => ctx.ast.seq_length(expr), } } } diff --git a/vir/src/data.rs b/vir/src/data.rs index de0efca5bff..b9bbab3b03b 100644 --- a/vir/src/data.rs +++ b/vir/src/data.rs @@ -29,6 +29,7 @@ pub struct LocalDeclData<'vir> { pub enum UnOpKind { Neg, Not, + SeqLen, } impl From for UnOpKind { fn from(value: mir::UnOp) -> Self { @@ -62,6 +63,7 @@ pub enum BinOpKind { Div, DivRational, Mod, + SeqIndex, // ... } impl From for BinOpKind { @@ -137,6 +139,9 @@ pub enum TypeData<'vir> { Ref, // TODO: typed references ? Perm, Predicate, // The type of a predicate application + Seq( + #[serde(with = "crate::serde::serde_ref")] Type<'vir>, + ), Unsupported(UnsupportedType<'vir>), } @@ -227,6 +232,7 @@ pub type PredicateAppData<'vir> = crate::gendata::PredicateAppGenData<'vir, !, ! pub type PredicateData<'vir> = crate::gendata::PredicateGenData<'vir, !, !>; pub type ProgramData<'vir> = crate::gendata::ProgramGenData<'vir, !, !>; pub type PureAssignData<'vir> = crate::gendata::PureAssignGenData<'vir, !, !>; +pub type SeqLiteralData<'vir> = crate::gendata::SeqLiteralGenData<'vir, !, !>; pub type StmtData<'vir> = crate::gendata::StmtGenData<'vir, !, !>; pub type StmtKindData<'vir> = crate::gendata::StmtKindGenData<'vir, !, !>; pub type TerminatorStmtData<'vir> = crate::gendata::TerminatorStmtGenData<'vir, !, !>; diff --git a/vir/src/debug.rs b/vir/src/debug.rs index 2c2b850653c..8ad2d9ba0bb 100644 --- a/vir/src/debug.rs +++ b/vir/src/debug.rs @@ -47,6 +47,13 @@ impl<'vir, Curr, Next> Debug for AccFieldGenData<'vir, Curr, Next> { impl<'vir, Curr, Next> Debug for BinOpGenData<'vir, Curr, Next> { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + if self.kind == BinOpKind::SeqIndex { + write!(f, "(")?; + self.lhs.fmt(f)?; + write!(f, ")[")?; + self.rhs.fmt(f)?; + return write!(f, "]"); + } write!(f, "(")?; self.lhs.fmt(f)?; write!( @@ -68,6 +75,7 @@ impl<'vir, Curr, Next> Debug for BinOpGenData<'vir, Curr, Next> { BinOpKind::Div => "\\", BinOpKind::DivRational => "/", BinOpKind::Mod => "%", + BinOpKind::SeqIndex => unreachable!(), } )?; self.rhs.fmt(f)?; @@ -162,6 +170,7 @@ impl<'vir, Curr, Next> Debug for ExprKindGenData<'vir, Curr, Next> { Self::Local(e) => e.fmt(f), Self::Old(e) => e.fmt(f), Self::PredicateApp(e) => e.fmt(f), + Self::SeqLiteral(e) => e.fmt(f), Self::Wand(e) => e.fmt(f), Self::Ternary(e) => e.fmt(f), Self::UnOp(e) => e.fmt(f), @@ -332,6 +341,14 @@ impl<'vir, Curr, Next> Debug for PredicateAppGenData<'vir, Curr, Next> { } } +impl<'vir, Curr, Next> Debug for SeqLiteralGenData<'vir, Curr, Next> { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "Seq(")?; + fmt_comma_sep(f, self.values)?; + write!(f, ")") + } +} + impl<'vir, Curr, Next> Debug for StmtGenData<'vir, Curr, Next> { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { if let Some(span) = self.span { @@ -483,6 +500,7 @@ impl<'vir> Debug for TypeData<'vir> { Self::Ref => write!(f, "Ref"), Self::Perm => write!(f, "Perm"), Self::Predicate => write!(f, "Predicate"), + Self::Seq(ty) => write!(f, "Seq[{ty:?}]"), Self::Unsupported(u) => u.fmt(f), } } @@ -504,12 +522,17 @@ impl<'vir, Curr, Next> Debug for UnOpGenData<'vir, Curr, Next> { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { write!( f, - "{}({:?})", + "{}({:?}){}", match self.kind { UnOpKind::Neg => "-", UnOpKind::Not => "!", + UnOpKind::SeqLen => "|", + }, + self.expr, + match self.kind { + UnOpKind::SeqLen => "|", + _ => "", }, - self.expr ) } } diff --git a/vir/src/gendata.rs b/vir/src/gendata.rs index 5158d34a976..c78e8b5e098 100644 --- a/vir/src/gendata.rs +++ b/vir/src/gendata.rs @@ -11,6 +11,16 @@ pub struct UnOpGenData<'vir, Curr, Next> { pub expr: ExprGen<'vir, Curr, Next>, } +impl<'vir, Curr, Next> UnOpGenData<'vir, Curr, Next> { + pub fn ty(&self) -> Type<'vir> { + match self.kind { + UnOpKind::Neg + | UnOpKind::Not => self.expr.ty(), + UnOpKind::SeqLen => &TypeData::Int, + } + } +} + #[derive(VirHash, VirReify, VirSerde)] pub struct BinOpGenData<'vir, Curr, Next> { #[vir(reify_pass)] @@ -33,6 +43,10 @@ impl<'vir, Curr, Next> BinOpGenData<'vir, Curr, Next> { self.lhs.ty() } BinOpKind::DivRational => &TypeData::Perm, + BinOpKind::SeqIndex => match self.lhs.ty() { + TypeData::Seq(elem_ty) => elem_ty, + _ => unreachable!(), + }, } } } @@ -114,6 +128,13 @@ pub struct WandGenData<'vir, Curr, Next> { pub rhs: ExprGen<'vir, Curr, Next>, } +#[derive(VirHash, VirReify, VirSerde)] +pub struct SeqLiteralGenData<'vir, Curr, Next> { + pub values: &'vir [ExprGen<'vir, Curr, Next>], + #[vir(reify_pass, is_ref)] + pub ty: Type<'vir>, +} + /* // TODO: something like this would be a cleaner solution for ExprGenData's // generic; when tested, this runs into an infinite loop in rustc ...? @@ -165,7 +186,8 @@ pub enum ExprKindGenData<'vir, Curr: 'vir, Next: 'vir> { // perm ops? // container ops? // map ops? - // sequence, map, set, multiset literals + // map, set, multiset literals + SeqLiteral(SeqLiteralGen<'vir, Curr, Next>), Ternary(TernaryGen<'vir, Curr, Next>), Exists(ExistsGen<'vir, Curr, Next>), Forall(ForallGen<'vir, Curr, Next>), @@ -193,8 +215,9 @@ impl<'vir, Curr, Next> ExprKindGenData<'vir, Curr, Next> { ExprKindGenData::Result(ty) => ty, ExprKindGenData::AccField(_) => &TypeData::Bool, ExprKindGenData::Unfolding(f) => f.expr.ty(), - ExprKindGenData::UnOp(u) => u.expr.ty(), + ExprKindGenData::UnOp(u) => u.ty(), ExprKindGenData::BinOp(b) => b.ty(), + ExprKindGenData::SeqLiteral(s) => s.ty, ExprKindGenData::Ternary(t) => t.then.ty(), ExprKindGenData::Forall(_) => &TypeData::Bool, ExprKindGenData::Exists(_) => &TypeData::Bool, diff --git a/vir/src/genrefs.rs b/vir/src/genrefs.rs index 81691cbd839..169575731ea 100644 --- a/vir/src/genrefs.rs +++ b/vir/src/genrefs.rs @@ -28,6 +28,7 @@ pub type PredicateAppGen<'vir, Curr, Next> = pub type ProgramGen<'vir, Curr, Next> = &'vir crate::gendata::ProgramGenData<'vir, Curr, Next>; pub type PureAssignGen<'vir, Curr, Next> = &'vir crate::gendata::PureAssignGenData<'vir, Curr, Next>; +pub type SeqLiteralGen<'vir, Curr, Next> = &'vir crate::gendata::SeqLiteralGenData<'vir, Curr, Next>; pub type StmtGen<'vir, Curr, Next> = &'vir crate::gendata::StmtGenData<'vir, Curr, Next>; pub type StmtKindGen<'vir, Curr, Next> = &'vir crate::gendata::StmtKindGenData<'vir, Curr, Next>; pub type TerminatorStmtGen<'vir, Curr, Next> = diff --git a/vir/src/macros.rs b/vir/src/macros.rs index 82316ee23ef..190849cee00 100644 --- a/vir/src/macros.rs +++ b/vir/src/macros.rs @@ -179,16 +179,16 @@ macro_rules! vir_domain_axiom { let val_ex = $vcx.mk_local_ex("val", $crate::vir_type!($vcx; $ty)); let inner = $b.apply($vcx, [val_ex]); $vcx.mk_domain_axiom( - $vcx.alloc_str(&format!( - "ax_inverse_{}_{}", + $crate::ViperIdent::new($vcx.alloc_str(&format!( + "ax_inverise_{}_{}", $a.name(), $b.name(), - )), + ))), $vcx.mk_forall_expr( $vcx.alloc_slice(&[ $vcx.mk_local_decl("val", $crate::vir_type!($vcx; $ty)), ]), - $vcx.alloc_slice(&[$vcx.alloc_slice(&[inner])]), + $vcx.alloc_slice(&[$vcx.mk_trigger($vcx.alloc_slice(&[inner]))]), $vcx.mk_bin_op_expr( $crate::BinOpKind::CmpEq, $a.apply($vcx, [inner]), @@ -342,6 +342,18 @@ impl<'vir> ExprApply<'vir, crate::Expr<'vir>> vcx.mk_predicate_app_expr(self.apply(vcx, args, None)) } } +impl<'vir, const N: usize> ExprApply<'vir, crate::Expr<'vir>> + for crate::PredicateIdent<'vir, crate::KnownArity<'vir, N>> +{ + fn expr_apply( + &self, + vcx: &'vir crate::VirCtxt<'vir>, + args: &[crate::Expr<'vir>], + ) -> crate::Expr<'vir> { + assert_eq!(args.len(), N); + vcx.mk_predicate_app_expr(self.apply(vcx, args.try_into().unwrap(), None)) + } +} impl<'vir> ExprApply<'vir, crate::Expr<'vir>> for crate::Field<'vir> { fn expr_apply( &self, @@ -412,11 +424,9 @@ macro_rules! expr { ), $crate::expr!(@expr_one; $($rhs)*), )); } }; - (@expr($output:ident); acc( [ $outer:expr ]( $($args:tt)* ) ) ) => { { $output.push(vcx!().mk_predicate_app_expr( - $outer.apply(vcx!(), - $crate::expr!(@expr_list; $($args)*).as_slice(), - None, - ) + (@expr($output:ident); acc( [ $outer:expr ]( $($args:tt)* ) ) ) => { { $output.push($outer.expr_apply( + vcx!(), + $crate::expr!(@expr_list; $($args)*).as_slice(), )); } }; (@expr($output:ident); acc_field( [ $outer:expr ]( $($args:tt)* ) ) ) => { { $output.push(vcx!().mk_acc_field_expr( $crate::expr!(@expr_one; $($args)*), @@ -429,14 +439,24 @@ macro_rules! expr { Some(vcx!().mk_wildcard()), ) )); } }; + (@expr($output:ident); vpr_seq_len( $($args:tt)* ) ) => { { $output.push(vcx!().mk_unary_op_expr( + $crate::UnOpKind::SeqLen, + $crate::expr!(@expr_one; $($args)*), + )); } }; + (@expr($output:ident); old( $($args:tt)* ) ) => { { $output.push(vcx!().mk_old_expr( + $crate::expr!(@expr_one; $($args)*), + )); } }; + (@expr($output:ident); old_lhs( $($args:tt)* ) ) => { { $output.push(vcx!().mk_old_lhs_expr( + $crate::expr!(@expr_one; $($args)*), + )); } }; (@expr($output:ident); [ $outer:expr ]( ) ) => { { $output.push($outer.expr_apply( vcx!(), &[], )); } }; - (@expr($output:ident); [ $outer:expr ]( $($args:tt)* ) ) => { { $output.push($outer.expr_apply( + (@expr($output:ident); [ $outer:expr ]( $($args:tt)* ) $($rest:tt)* ) => { { $output.push($outer.expr_apply( vcx!(), $crate::expr!(@expr_list; $($args)*).as_slice(), - )); } }; + )); $crate::expr!(@expr_done($output); $($rest)*); } }; (@expr($output:ident); [ $outer:expr ] ) => { { $output.push($outer.expr(vcx!())); } }; (@expr($output:ident); ..[ $outer:expr ] ) => { { $output.extend($outer.iter().map(|e| e.expr(vcx!()))); } }; (@expr($output:ident); ( $($lhs:tt)+ ) => ( $($rhs:tt)+ )) => { { $output.push(vcx!().mk_bin_op_expr( @@ -476,6 +496,7 @@ macro_rules! expr { $crate::expr!(@expr_one; $($lhs)*), $crate::expr!(@expr_one; $($rhs)*), )); } }; + (@expr($output:ident); 0) => { { $output.push(vcx!().mk_uint::<0>()); } }; (@expr($output:ident); null) => { { $output.push(vcx!().mk_null()); } }; (@expr($output:ident); true) => { { $output.push(vcx!().mk_bool::()); } }; (@expr($output:ident); false) => { { $output.push(vcx!().mk_bool::()); } }; diff --git a/vir/src/make.rs b/vir/src/make.rs index 62629a46563..6526cb44197 100644 --- a/vir/src/make.rs +++ b/vir/src/make.rs @@ -157,6 +157,11 @@ cfg_if! { ExprKindGenData::UnOp(UnOpGenData { expr, .. }) => { check_expr_bindings(m, *expr); }, + ExprKindGenData::SeqLiteral(SeqLiteralGenData { values, .. }) => { + for value in values.iter() { + check_expr_bindings(m, value); + } + }, ExprKindGenData::Ternary(TernaryGenData { cond, then, else_}) => { check_expr_bindings(m, *cond); check_expr_bindings(m, *then); @@ -850,6 +855,17 @@ impl<'tcx> VirCtxt<'tcx> { .unwrap_or_else(|| self.mk_bool::()) } + pub fn mk_ty_seq<'vir>(&'vir self, elem_ty: Type<'vir>) -> Type<'vir> { + self.alloc(TypeData::Seq(elem_ty)) + } + + pub fn mk_seq_lit<'vir, Curr, Next>(&'vir self, values: &'vir [ExprGen<'vir, Curr, Next>], elem_ty: Type<'vir>) -> ExprGen<'vir, Curr, Next> { + self.alloc(ExprGenData::new(self.alloc(ExprKindGenData::SeqLiteral(self.alloc(SeqLiteralGenData { + values, + ty: self.mk_ty_seq(elem_ty), + }))))) + } + const fn get_int_data(rust_ty: &ty::TyKind) -> (u32, bool) { match rust_ty { ty::Int(ty::IntTy::Isize) => ((std::mem::size_of::() * 8) as u32, true), diff --git a/vir/src/refs.rs b/vir/src/refs.rs index bc61a1dbbe5..feb451475e8 100644 --- a/vir/src/refs.rs +++ b/vir/src/refs.rs @@ -26,6 +26,7 @@ pub type Predicate<'vir> = &'vir crate::data::PredicateData<'vir>; pub type PredicateApp<'vir> = &'vir crate::data::PredicateAppData<'vir>; pub type Program<'vir> = &'vir crate::data::ProgramData<'vir>; pub type PureAssign<'vir> = &'vir crate::data::PureAssignData<'vir>; +pub type SeqLiteral<'vir> = &'vir crate::data::SeqLiteralData<'vir>; pub type Stmt<'vir> = &'vir crate::data::StmtData<'vir>; pub type StmtKind<'vir> = &'vir crate::data::StmtKindData<'vir>; pub type TerminatorStmt<'vir> = &'vir crate::data::TerminatorStmtData<'vir>; diff --git a/vir/src/reify.rs b/vir/src/reify.rs index 3b9e487e6ae..3e9d4571a01 100644 --- a/vir/src/reify.rs +++ b/vir/src/reify.rs @@ -39,6 +39,7 @@ impl<'vir, Curr: Copy, NextA, NextB> Reify<'vir, Curr> } ExprKindGenData::UnOp(v) => vcx.alloc(ExprKindGenData::UnOp(v.reify(vcx, lctx))), ExprKindGenData::BinOp(v) => vcx.alloc(ExprKindGenData::BinOp(v.reify(vcx, lctx))), + ExprKindGenData::SeqLiteral(v) => vcx.alloc(ExprKindGenData::SeqLiteral(v.reify(vcx, lctx))), ExprKindGenData::Ternary(v) => vcx.alloc(ExprKindGenData::Ternary(v.reify(vcx, lctx))), ExprKindGenData::Forall(v) => vcx.alloc(ExprKindGenData::Forall(v.reify(vcx, lctx))), ExprKindGenData::Exists(v) => vcx.alloc(ExprKindGenData::Exists(v.reify(vcx, lctx))),