diff --git a/prusti-encoder/src/encoders/mir_fn/mod.rs b/prusti-encoder/src/encoders/mir_fn/mod.rs index ba9d492bc79..0aee2668542 100644 --- a/prusti-encoder/src/encoders/mir_fn/mod.rs +++ b/prusti-encoder/src/encoders/mir_fn/mod.rs @@ -6,7 +6,7 @@ pub use function::*; pub use method::*; pub use signature::*; -use crate::encoders::ty::generics::{GArgs, GParams}; +use crate::encoders::ty::generics::{GArgs, GParams, trait_impls::TraitImplEnc}; use prusti_interface::specs::specifications::SpecQuery; use prusti_rustc_interface::{hir, middle::ty, span::def_id::DefId}; @@ -61,4 +61,13 @@ pub fn encode_all_in_crate<'tcx>(tcx: ty::TyCtxt<'tcx>) { } } } + + // This creates the impl encoding for all traits in the crate + // To iterate over all _visible_ impl blocks, + // use tcx.visible_traits and tcx.all_impls(trait_id) + for def_id in tcx.hir_crate_items(()).definitions() { + if let hir::def::DefKind::Impl { of_trait: true } = tcx.def_kind(def_id) { + TraitImplEnc::encode(def_id.to_def_id(), false).unwrap(); + } + } } diff --git a/prusti-encoder/src/encoders/ty/generics/args.rs b/prusti-encoder/src/encoders/ty/generics/args.rs index 5bf0cd8205c..f64e37ec43b 100644 --- a/prusti-encoder/src/encoders/ty/generics/args.rs +++ b/prusti-encoder/src/encoders/ty/generics/args.rs @@ -10,6 +10,11 @@ pub struct GArgs<'tcx> { pub(super) args: &'tcx [ty::GenericArg<'tcx>], } +pub enum GParamVariant<'tcx> { + Param(ty::ParamTy), + Alias(ty::AliasTy<'tcx>), +} + impl<'tcx> GArgs<'tcx> { pub fn new(context: impl Into>, args: &'tcx [ty::GenericArg<'tcx>]) -> Self { GArgs { @@ -34,12 +39,11 @@ impl<'tcx> GArgs<'tcx> { self.context.normalize(ty) } - pub fn expect_param(self) -> ty::ParamTy { + pub fn expect_param(self) -> GParamVariant<'tcx> { assert_eq!(self.args.len(), 1); match self.args[0].expect_ty().kind() { - ty::TyKind::Param(p) => *p, - // TODO: this needs to be changed to support type aliases - ty::TyKind::Alias(..) => panic!("type aliases are not currently supported"), + ty::TyKind::Param(p) => GParamVariant::Param(*p), + ty::TyKind::Alias(_k, t) => GParamVariant::Alias(*t), other => panic!("expected type parameter, {other:?}"), } } diff --git a/prusti-encoder/src/encoders/ty/generics/mod.rs b/prusti-encoder/src/encoders/ty/generics/mod.rs index 66aea9f38af..626ec23318a 100644 --- a/prusti-encoder/src/encoders/ty/generics/mod.rs +++ b/prusti-encoder/src/encoders/ty/generics/mod.rs @@ -3,6 +3,8 @@ mod params; mod casters; mod args_ty; mod args; +pub mod traits; +pub mod trait_impls; pub use args::*; pub use args_ty::*; diff --git a/prusti-encoder/src/encoders/ty/generics/params.rs b/prusti-encoder/src/encoders/ty/generics/params.rs index 0220487fada..4061df1d525 100644 --- a/prusti-encoder/src/encoders/ty/generics/params.rs +++ b/prusti-encoder/src/encoders/ty/generics/params.rs @@ -1,6 +1,6 @@ use prusti_interface::specs::typed::ExternSpecKind; use prusti_rustc_interface::{ - middle::ty, + middle::{ty, ty::TyKind}, span::{def_id::DefId, symbol}, }; use task_encoder::{EncodeFullResult, TaskEncoder, TaskEncoderDependencies}; @@ -8,7 +8,12 @@ use vir::{CastType, HasType}; use crate::encoders::{ TyUsePureEnc, - ty::{RustTyDecomposition, data::TySpecifics, generics::GArgsTyEnc, lifted::TyConstructorEnc}, + ty::{ + RustTyDecomposition, + data::TySpecifics, + generics::{GArgsTyEnc, GParamVariant, traits::TraitEnc}, + lifted::TyConstructorEnc, + }, }; /// The list of defined parameters in a given context. E.g. the type parameters @@ -176,7 +181,6 @@ impl<'vir> From for GParams<'vir> { /// `fn foo(x: U)` into the Viper `method foo(x: Ref, T: Type, U: Type)` /// (handles the type parameters). pub struct GenericParamsEnc; - #[derive(Debug, Clone)] pub struct GenericParams<'vir> { ty_args: &'vir [vir::TypeTyVal<'vir>], @@ -250,7 +254,26 @@ impl<'vir> GenericParams<'vir> { ) -> vir::ExprTyVal<'vir> { if let TySpecifics::Param(()) = &ty.ty.specifics { let param = ty.args.expect_param(); - return self.ty_exprs[self.map_idx(param.index).unwrap()]; + return match param { + GParamVariant::Param(p) => self.ty_exprs[self.map_idx(p.index).unwrap()], + GParamVariant::Alias(a) => vir::with_vcx(|vcx| { + let tcx = vcx.tcx(); + let trait_did = tcx.associated_item(a.def_id).container_id(tcx); + let trait_data = deps.require_dep::(trait_did).unwrap(); + let tys = &a + .args + .iter() + .map(|arg| match arg.expect_ty().kind() { + TyKind::Param(p) => self.ty_exprs[self.map_idx(p.index).unwrap()], + _ => self.ty_expr( + deps, + RustTyDecomposition::from_ty(arg.expect_ty(), tcx, ty.args.context), + ), + }) + .collect::>(); + (trait_data.type_did_fun_mapping.get(&a.def_id).unwrap())(tys) + }), + }; } let ty_constructor = deps .require_ref::(ty.ty) diff --git a/prusti-encoder/src/encoders/ty/generics/trait_impls.rs b/prusti-encoder/src/encoders/ty/generics/trait_impls.rs new file mode 100644 index 00000000000..d6bd48e454e --- /dev/null +++ b/prusti-encoder/src/encoders/ty/generics/trait_impls.rs @@ -0,0 +1,133 @@ +use std::iter; + +use prusti_rustc_interface::{middle::ty::AssocKind, span::def_id::DefId}; +use task_encoder::{EncodeFullResult, TaskEncoder, TaskEncoderDependencies}; +use vir::{CastType, Domain, vir_format_identifier}; + +use crate::encoders::ty::{ + RustTyDecomposition, + generics::{GArgs, GArgsTyEnc, GParams, GenericParamsEnc, traits::TraitEnc}, +}; + +pub struct TraitImplEnc; + +impl TaskEncoder for TraitImplEnc { + task_encoder::encoder_cache!(TraitImplEnc); + + fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> Self::TaskKey<'vir> { + *task + } + + fn emit_outputs<'vir>(program: &mut task_encoder::Program<'vir>) { + for dom in TraitImplEnc::all_outputs_local_no_errors() { + program.add_domain(dom); + } + } + + type TaskDescription<'vir> = DefId; + type OutputFullLocal<'vir> = Domain<'vir>; + + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + deps.emit_output_ref(*task_key, ())?; + + vir::with_vcx(|vcx| { + let tcx = vcx.tcx(); + + let ctx = GParams::from(*task_key); + + let params = deps.require_dep::(ctx)?; + + let trait_ref = tcx.impl_trait_ref(task_key).unwrap().instantiate_identity(); + let trait_did = trait_ref.def_id; + let trait_data = deps.require_dep::(trait_did)?; + + let args = deps.require_dep::(GArgs::new(ctx, trait_ref.args))?; + + let mut axs = Vec::new(); + + let struct_ty = tcx.type_of(task_key).instantiate_identity(); + + let impl_fun = trait_data.impl_fun; + let trait_ty_decls = params + .ty_decls() + .iter() + .map(|dec| dec.upcast_ty()) + .collect::>(); + let trait_tys = args.get_ty(); + + axs.push( + vcx.mk_domain_axiom( + vir_format_identifier!(vcx, "{}_impl_{}", trait_data.trait_name, struct_ty), + vir::expr! {forall ..[trait_ty_decls] :: {[impl_fun(trait_tys)]} [impl_fun(trait_tys)]} + ) + ); + + tcx.associated_items(*task_key) + .in_definition_order() + .filter(|item| matches!(item.kind, AssocKind::Type { data: _ })) + .for_each(|impl_item| { + let assoc_fun = trait_data.type_did_fun_mapping.get(&impl_item.trait_item_def_id.unwrap()).unwrap(); + // construct arguments for assoc_item function + // parameters of the trait are substituted + // by the arguments used in the impl + // parameters of the associated type are kept + + // parameters of assoc item include already substituted arguments + let assoc_params = deps + .require_dep::(GParams::from(impl_item.def_id)) + .unwrap(); + + // the type we want to resolve the type alias to + let assoc_type_expr = assoc_params.ty_expr( + deps, + RustTyDecomposition::from_ty( + tcx.type_of(impl_item.def_id).instantiate_identity(), + tcx, + GParams::from(impl_item.def_id), + ), + ); + let assoc_decls = assoc_params + .ty_decls() + .iter() + .map(|dec| dec.upcast_ty()) + .collect::>(); + + // Combine substituted trait ty decls with the decls of the associated type + let mut trait_ty_decls = trait_ty_decls.clone(); + trait_ty_decls.extend_from_slice(&assoc_decls[params.ty_exprs().len()..]); + + // Combine substituted trait params with the params of the associated type + let trait_tys = vcx.alloc_slice(&iter::empty().chain(args.get_ty().to_owned()).chain(assoc_params.ty_exprs()[params.ty_exprs().len()..].to_owned()).collect::>()); + axs.push(vcx.mk_domain_axiom( + vir_format_identifier!( + vcx, + "{}_Assoc_{}_{}", + trait_data.trait_name, + tcx.item_name(impl_item.def_id), + struct_ty + ), + vir::expr! {forall ..[trait_ty_decls] :: {[assoc_fun(trait_tys)]} ([assoc_fun(trait_tys)]) == (assoc_type_expr)} + )); + }); + + Ok(( + vcx.mk_domain( + vir_format_identifier!( + vcx, + "t_{}_{}", + trait_data.trait_name, + tcx.type_of(*task_key).instantiate_identity().to_string() + ), + &[], + vcx.alloc_slice(&axs), + &[], + None, + ), + (), + )) + }) + } +} diff --git a/prusti-encoder/src/encoders/ty/generics/traits.rs b/prusti-encoder/src/encoders/ty/generics/traits.rs new file mode 100644 index 00000000000..c863f5f86c2 --- /dev/null +++ b/prusti-encoder/src/encoders/ty/generics/traits.rs @@ -0,0 +1,96 @@ +use std::collections::HashMap; + +use prusti_rustc_interface::{middle::ty::AssocKind, span::def_id::DefId}; +use task_encoder::{EncodeFullResult, TaskEncoder, TaskEncoderDependencies}; +use vir::{FunctionIdn, vir_format_identifier}; + +use crate::encoders::ty::generics::{GParams, GenericParamsEnc}; + +pub struct TraitEnc; + +#[derive(Debug, Clone)] +pub struct TraitData<'vir> { + pub trait_name: &'vir str, + pub type_did_fun_mapping: HashMap>, + pub impl_fun: FunctionIdn<'vir, vir::ManyTyVal, vir::Bool>, +} + +impl TaskEncoder for TraitEnc { + task_encoder::encoder_cache!(TraitEnc); + + fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> Self::TaskKey<'vir> { + *task + } + + type TaskDescription<'vir> = DefId; + + type OutputFullDependency<'vir> = TraitData<'vir>; + type OutputFullLocal<'vir> = vir::Domain<'vir>; + + fn emit_outputs<'vir>(program: &mut task_encoder::Program<'vir>) { + for dom in TraitEnc::all_outputs_local_no_errors() { + program.add_domain(dom); + } + } + + fn do_encode_full<'vir>( + task_key: &Self::TaskKey<'vir>, + deps: &mut TaskEncoderDependencies<'vir, Self>, + ) -> EncodeFullResult<'vir, Self> { + deps.emit_output_ref(*task_key, ())?; + vir::with_vcx(|vcx| { + let tcx = vcx.tcx(); + let params = deps.require_dep::(GParams::from(*task_key))?; + let trait_name = vcx.alloc_str(tcx.item_name(task_key).as_str()); + let type_did_fun_mapping = tcx + .associated_items(task_key) + .in_definition_order() + .filter(|item| matches!(item.kind, AssocKind::Type { data: _ })) + .map(|item| { + let params_type = deps + .require_dep::(GParams::from(item.def_id)) + .unwrap(); + ( + item.def_id, + FunctionIdn::new( + vir_format_identifier!( + vcx, + "{}_Assoc_{}_func", + trait_name, + tcx.item_name(item.def_id), + ), + vcx.alloc_slice(&vec![vir::TYPE_TYVAL; params_type.ty_exprs().len()]), // params_type also includes parameters of trait itself + vir::TYPE_TYVAL, + ), + ) + }) + .collect::>>(); + let mut funcs = type_did_fun_mapping + .values() + .map(|function_idn| vcx.mk_domain_function(*function_idn, false, None)) + .collect::>(); + let impl_fun = FunctionIdn::new( + vir_format_identifier!(vcx, "{}_impl", trait_name), + vcx.alloc_slice(&(vec![vir::TYPE_TYVAL; params.ty_exprs().len()])), + vir::TYPE_BOOL, + ); + let impl_fun_data = vcx.mk_domain_function(impl_fun, false, None); + funcs.push(impl_fun_data); + let trait_domain = vcx.mk_domain( + vir_format_identifier!(vcx, "t_{}", trait_name), + &[], + &[], + vcx.alloc_slice(funcs.as_slice()), + None, + ); + Ok(( + trait_domain, + TraitData { + trait_name, + type_did_fun_mapping, + impl_fun, + }, + )) + }) + } +} diff --git a/prusti-encoder/src/lib.rs b/prusti-encoder/src/lib.rs index 014c6113637..3bfc17c0bef 100644 --- a/prusti-encoder/src/lib.rs +++ b/prusti-encoder/src/lib.rs @@ -21,7 +21,7 @@ use crate::encoders::{ Impure, Pure, custom::PairUseEnc, ty::{ - generics::GArgsCastEnc, + generics::{GArgsCastEnc, trait_impls::TraitImplEnc, traits::TraitEnc}, interpretation::bitvec::BitVecEnc, lifted::{TyConstructorEnc, TypeOfEnc}, }, @@ -103,6 +103,8 @@ pub fn test_entrypoint<'tcx>( program.header("custom"); PairUseEnc::emit_outputs(&mut program); + TraitEnc::emit_outputs(&mut program); + TraitImplEnc::emit_outputs(&mut program); if std::env::var("LOCAL_TESTING").is_ok() { std::fs::write("local-testing/simple.vpr", program.code()).unwrap(); diff --git a/prusti-tests/tests/verify/pass/traits/assoc_type.rs b/prusti-tests/tests/verify/pass/traits/assoc_type.rs new file mode 100644 index 00000000000..de09a75b58a --- /dev/null +++ b/prusti-tests/tests/verify/pass/traits/assoc_type.rs @@ -0,0 +1,20 @@ +fn foo(x: Y::SomeType) {} + +trait MyTrait { + type SomeType; +} + +struct St1 {} +struct St2 {} + +impl MyTrait for St1 { + type SomeType = u32; +} + +impl MyTrait for St2 { + type SomeType = u64; +} + +fn bar() { + foo::(5); +} diff --git a/prusti-tests/tests/verify/pass/traits/assoc_type_gen_assoc.rs b/prusti-tests/tests/verify/pass/traits/assoc_type_gen_assoc.rs new file mode 100644 index 00000000000..9e7c2e55957 --- /dev/null +++ b/prusti-tests/tests/verify/pass/traits/assoc_type_gen_assoc.rs @@ -0,0 +1,20 @@ +fn foo(x: Y::SomeType) {} + +trait MyTrait { + type SomeType; +} + +struct St1 {} +struct St2 {} + +impl MyTrait for St1 { + type SomeType = X; +} + +impl MyTrait for St2 { + type SomeType = u64; +} + +fn bar() { + foo::(5); +} diff --git a/prusti-tests/tests/verify/pass/traits/assoc_type_gen_trait.rs b/prusti-tests/tests/verify/pass/traits/assoc_type_gen_trait.rs new file mode 100644 index 00000000000..d50084655d0 --- /dev/null +++ b/prusti-tests/tests/verify/pass/traits/assoc_type_gen_trait.rs @@ -0,0 +1,44 @@ +fn foo, Z>(x: Y::SomeType, y: Y::SomeOtherType, z: Y) { + let res: X = z.gen(); +} + +trait MyTrait { + type SomeType; + type SomeOtherType; + + fn gen(self) -> T; +} + +struct St1 { + x: T, +} +struct St2 { + y: T, +} + +impl MyTrait for St1 { + type SomeType = T; + type SomeOtherType = T2; + + fn gen(self) -> T { + self.x + } +} + +impl MyTrait for St2 { + type SomeType = u64; + type SomeOtherType = SomeWrapper; + + fn gen(self) -> T { + self.y + } +} + +fn bar() { + foo::, u32>(5.2, 6, St1 { x: 5.5 }); + foo::, bool>(5, SomeWrapper { val: false }, St2 { y: true }); +} + +struct SomeWrapper { + val: T, +} diff --git a/prusti-tests/tests/verify/pass/traits/assoc_type_gen_trait_gen_assoc.rs b/prusti-tests/tests/verify/pass/traits/assoc_type_gen_trait_gen_assoc.rs new file mode 100644 index 00000000000..67e3365e2ba --- /dev/null +++ b/prusti-tests/tests/verify/pass/traits/assoc_type_gen_trait_gen_assoc.rs @@ -0,0 +1,44 @@ +fn foo, Z>(x: Y::SomeType>, y: Y::SomeOtherType, z: Y) { + let res: X = z.gen(); +} + +trait MyTrait { + type SomeType; + type SomeOtherType; + + fn gen(self) -> T; +} + +struct St1 { + x: T, +} +struct St2 { + y: T, +} + +impl MyTrait for St1 { + type SomeType = A; + type SomeOtherType = T2; + + fn gen(self) -> T { + self.x + } +} + +impl MyTrait for St2 { + type SomeType = T; + type SomeOtherType = SomeWrapper; + + fn gen(self) -> T { + self.y + } +} + +fn bar() { + foo::, u32>(SomeWrapper { val: 5.5 }, 6, St1 { x: 5.5 }); + foo::, bool>(true, SomeWrapper { val: false }, St2 { y: true }); +} + +struct SomeWrapper { + val: T, +} diff --git a/prusti-tests/tests/verify/pass/traits/gen_assoc_type.rs b/prusti-tests/tests/verify/pass/traits/gen_assoc_type.rs new file mode 100644 index 00000000000..058e84bbe50 --- /dev/null +++ b/prusti-tests/tests/verify/pass/traits/gen_assoc_type.rs @@ -0,0 +1,24 @@ +fn foo(x: Y::SomeType) {} + +trait MyTrait { + type SomeType; +} + +struct St1 {} +struct St2 {} + +impl MyTrait for St1 { + type SomeType = SomeWrapper; +} + +impl MyTrait for St2 { + type SomeType = u64; +} + +fn bar() { + foo::(SomeWrapper { val: 5 }); +} + +struct SomeWrapper { + val: T, +} diff --git a/prusti-tests/tests/verify/pass/traits/primitive_impl.rs b/prusti-tests/tests/verify/pass/traits/primitive_impl.rs new file mode 100644 index 00000000000..a17a87abe4c --- /dev/null +++ b/prusti-tests/tests/verify/pass/traits/primitive_impl.rs @@ -0,0 +1,15 @@ +fn foo>(x: Y::SomeOtherType) {} + +trait SomeTrait { + type SomeType; + type SomeOtherType; +} + +impl SomeTrait for u32 { + type SomeType = X; + type SomeOtherType = Y; +} + +fn bar() { + foo::(5); +} diff --git a/vir/src/macros.rs b/vir/src/macros.rs index bebd8ac21c5..adc0b93e83a 100644 --- a/vir/src/macros.rs +++ b/vir/src/macros.rs @@ -402,9 +402,8 @@ macro_rules! expr_inner { ) }; (@forall_qvars($qvars:ident); :: $($tokens:tt)*) => { compile_error!(concat!("VIR missing triggers or body: `" , stringify!($($tokens)*), "`")) }; - (@forall_qvars($qvars:ident); , ..[$outer:ident] $($tokens:tt)*) => { { - $qvars.extend($outer.iter().map(|local| $crate::CastType::as_dyn(vcx!().mk_local_decl_local(local)))); - let $outer: Vec<$crate::ExprGen<_, _, _>> = $outer.iter().map(|local| vcx!().mk_local_ex_local(local)).collect(); + (@forall_qvars($qvars:ident); , ..[$outer_decls:ident] $($tokens:tt)*) => { { + $qvars.extend($outer_decls.clone()); $crate::expr_inner!(@forall_qvars($qvars); $($tokens)*) } }; (@forall_qvars($qvars:ident); , $qvar:ident : $qtype:tt $($tokens:tt)* ) => { {