Skip to content

Commit e7c85cd

Browse files
authored
Adding support for type aliases to associated types (#120)
Co-authored-by: --global <--global>
1 parent a99c4c7 commit e7c85cd

File tree

14 files changed

+448
-13
lines changed

14 files changed

+448
-13
lines changed

prusti-encoder/src/encoders/mir_fn/mod.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ pub use function::*;
66
pub use method::*;
77
pub use signature::*;
88

9-
use crate::encoders::ty::generics::{GArgs, GParams};
9+
use crate::encoders::ty::generics::{GArgs, GParams, trait_impls::TraitImplEnc};
1010

1111
use prusti_interface::specs::specifications::SpecQuery;
1212
use prusti_rustc_interface::{hir, middle::ty, span::def_id::DefId};
@@ -61,4 +61,13 @@ pub fn encode_all_in_crate<'tcx>(tcx: ty::TyCtxt<'tcx>) {
6161
}
6262
}
6363
}
64+
65+
// This creates the impl encoding for all traits in the crate
66+
// To iterate over all _visible_ impl blocks,
67+
// use tcx.visible_traits and tcx.all_impls(trait_id)
68+
for def_id in tcx.hir_crate_items(()).definitions() {
69+
if let hir::def::DefKind::Impl { of_trait: true } = tcx.def_kind(def_id) {
70+
TraitImplEnc::encode(def_id.to_def_id(), false).unwrap();
71+
}
72+
}
6473
}

prusti-encoder/src/encoders/ty/generics/args.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ pub struct GArgs<'tcx> {
1010
pub(super) args: &'tcx [ty::GenericArg<'tcx>],
1111
}
1212

13+
pub enum GParamVariant<'tcx> {
14+
Param(ty::ParamTy),
15+
Alias(ty::AliasTy<'tcx>),
16+
}
17+
1318
impl<'tcx> GArgs<'tcx> {
1419
pub fn new(context: impl Into<GParams<'tcx>>, args: &'tcx [ty::GenericArg<'tcx>]) -> Self {
1520
GArgs {
@@ -34,12 +39,11 @@ impl<'tcx> GArgs<'tcx> {
3439
self.context.normalize(ty)
3540
}
3641

37-
pub fn expect_param(self) -> ty::ParamTy {
42+
pub fn expect_param(self) -> GParamVariant<'tcx> {
3843
assert_eq!(self.args.len(), 1);
3944
match self.args[0].expect_ty().kind() {
40-
ty::TyKind::Param(p) => *p,
41-
// TODO: this needs to be changed to support type aliases
42-
ty::TyKind::Alias(..) => panic!("type aliases are not currently supported"),
45+
ty::TyKind::Param(p) => GParamVariant::Param(*p),
46+
ty::TyKind::Alias(_k, t) => GParamVariant::Alias(*t),
4347
other => panic!("expected type parameter, {other:?}"),
4448
}
4549
}

prusti-encoder/src/encoders/ty/generics/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ mod params;
33
mod casters;
44
mod args_ty;
55
mod args;
6+
pub mod traits;
7+
pub mod trait_impls;
68

79
pub use args::*;
810
pub use args_ty::*;

prusti-encoder/src/encoders/ty/generics/params.rs

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
use prusti_interface::specs::typed::ExternSpecKind;
22
use prusti_rustc_interface::{
3-
middle::ty,
3+
middle::{ty, ty::TyKind},
44
span::{def_id::DefId, symbol},
55
};
66
use task_encoder::{EncodeFullResult, TaskEncoder, TaskEncoderDependencies};
77
use vir::{CastType, HasType};
88

99
use crate::encoders::{
1010
TyUsePureEnc,
11-
ty::{RustTyDecomposition, data::TySpecifics, generics::GArgsTyEnc, lifted::TyConstructorEnc},
11+
ty::{
12+
RustTyDecomposition,
13+
data::TySpecifics,
14+
generics::{GArgsTyEnc, GParamVariant, traits::TraitEnc},
15+
lifted::TyConstructorEnc,
16+
},
1217
};
1318

1419
/// The list of defined parameters in a given context. E.g. the type parameters
@@ -176,7 +181,6 @@ impl<'vir> From<DefId> for GParams<'vir> {
176181
/// `fn foo<T, U>(x: U)` into the Viper `method foo(x: Ref, T: Type, U: Type)`
177182
/// (handles the type parameters).
178183
pub struct GenericParamsEnc;
179-
180184
#[derive(Debug, Clone)]
181185
pub struct GenericParams<'vir> {
182186
ty_args: &'vir [vir::TypeTyVal<'vir>],
@@ -250,7 +254,26 @@ impl<'vir> GenericParams<'vir> {
250254
) -> vir::ExprTyVal<'vir> {
251255
if let TySpecifics::Param(()) = &ty.ty.specifics {
252256
let param = ty.args.expect_param();
253-
return self.ty_exprs[self.map_idx(param.index).unwrap()];
257+
return match param {
258+
GParamVariant::Param(p) => self.ty_exprs[self.map_idx(p.index).unwrap()],
259+
GParamVariant::Alias(a) => vir::with_vcx(|vcx| {
260+
let tcx = vcx.tcx();
261+
let trait_did = tcx.associated_item(a.def_id).container_id(tcx);
262+
let trait_data = deps.require_dep::<TraitEnc>(trait_did).unwrap();
263+
let tys = &a
264+
.args
265+
.iter()
266+
.map(|arg| match arg.expect_ty().kind() {
267+
TyKind::Param(p) => self.ty_exprs[self.map_idx(p.index).unwrap()],
268+
_ => self.ty_expr(
269+
deps,
270+
RustTyDecomposition::from_ty(arg.expect_ty(), tcx, ty.args.context),
271+
),
272+
})
273+
.collect::<Vec<_>>();
274+
(trait_data.type_did_fun_mapping.get(&a.def_id).unwrap())(tys)
275+
}),
276+
};
254277
}
255278
let ty_constructor = deps
256279
.require_ref::<TyConstructorEnc>(ty.ty)
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
use std::iter;
2+
3+
use prusti_rustc_interface::{middle::ty::AssocKind, span::def_id::DefId};
4+
use task_encoder::{EncodeFullResult, TaskEncoder, TaskEncoderDependencies};
5+
use vir::{CastType, Domain, vir_format_identifier};
6+
7+
use crate::encoders::ty::{
8+
RustTyDecomposition,
9+
generics::{GArgs, GArgsTyEnc, GParams, GenericParamsEnc, traits::TraitEnc},
10+
};
11+
12+
pub struct TraitImplEnc;
13+
14+
impl TaskEncoder for TraitImplEnc {
15+
task_encoder::encoder_cache!(TraitImplEnc);
16+
17+
fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> Self::TaskKey<'vir> {
18+
*task
19+
}
20+
21+
fn emit_outputs<'vir>(program: &mut task_encoder::Program<'vir>) {
22+
for dom in TraitImplEnc::all_outputs_local_no_errors() {
23+
program.add_domain(dom);
24+
}
25+
}
26+
27+
type TaskDescription<'vir> = DefId;
28+
type OutputFullLocal<'vir> = Domain<'vir>;
29+
30+
fn do_encode_full<'vir>(
31+
task_key: &Self::TaskKey<'vir>,
32+
deps: &mut TaskEncoderDependencies<'vir, Self>,
33+
) -> EncodeFullResult<'vir, Self> {
34+
deps.emit_output_ref(*task_key, ())?;
35+
36+
vir::with_vcx(|vcx| {
37+
let tcx = vcx.tcx();
38+
39+
let ctx = GParams::from(*task_key);
40+
41+
let params = deps.require_dep::<GenericParamsEnc>(ctx)?;
42+
43+
let trait_ref = tcx.impl_trait_ref(task_key).unwrap().instantiate_identity();
44+
let trait_did = trait_ref.def_id;
45+
let trait_data = deps.require_dep::<TraitEnc>(trait_did)?;
46+
47+
let args = deps.require_dep::<GArgsTyEnc>(GArgs::new(ctx, trait_ref.args))?;
48+
49+
let mut axs = Vec::new();
50+
51+
let struct_ty = tcx.type_of(task_key).instantiate_identity();
52+
53+
let impl_fun = trait_data.impl_fun;
54+
let trait_ty_decls = params
55+
.ty_decls()
56+
.iter()
57+
.map(|dec| dec.upcast_ty())
58+
.collect::<Vec<_>>();
59+
let trait_tys = args.get_ty();
60+
61+
axs.push(
62+
vcx.mk_domain_axiom(
63+
vir_format_identifier!(vcx, "{}_impl_{}", trait_data.trait_name, struct_ty),
64+
vir::expr! {forall ..[trait_ty_decls] :: {[impl_fun(trait_tys)]} [impl_fun(trait_tys)]}
65+
)
66+
);
67+
68+
tcx.associated_items(*task_key)
69+
.in_definition_order()
70+
.filter(|item| matches!(item.kind, AssocKind::Type { data: _ }))
71+
.for_each(|impl_item| {
72+
let assoc_fun = trait_data.type_did_fun_mapping.get(&impl_item.trait_item_def_id.unwrap()).unwrap();
73+
// construct arguments for assoc_item function
74+
// parameters of the trait are substituted
75+
// by the arguments used in the impl
76+
// parameters of the associated type are kept
77+
78+
// parameters of assoc item include already substituted arguments
79+
let assoc_params = deps
80+
.require_dep::<GenericParamsEnc>(GParams::from(impl_item.def_id))
81+
.unwrap();
82+
83+
// the type we want to resolve the type alias to
84+
let assoc_type_expr = assoc_params.ty_expr(
85+
deps,
86+
RustTyDecomposition::from_ty(
87+
tcx.type_of(impl_item.def_id).instantiate_identity(),
88+
tcx,
89+
GParams::from(impl_item.def_id),
90+
),
91+
);
92+
let assoc_decls = assoc_params
93+
.ty_decls()
94+
.iter()
95+
.map(|dec| dec.upcast_ty())
96+
.collect::<Vec<_>>();
97+
98+
// Combine substituted trait ty decls with the decls of the associated type
99+
let mut trait_ty_decls = trait_ty_decls.clone();
100+
trait_ty_decls.extend_from_slice(&assoc_decls[params.ty_exprs().len()..]);
101+
102+
// Combine substituted trait params with the params of the associated type
103+
let trait_tys = vcx.alloc_slice(&iter::empty().chain(args.get_ty().to_owned()).chain(assoc_params.ty_exprs()[params.ty_exprs().len()..].to_owned()).collect::<Vec<_>>());
104+
axs.push(vcx.mk_domain_axiom(
105+
vir_format_identifier!(
106+
vcx,
107+
"{}_Assoc_{}_{}",
108+
trait_data.trait_name,
109+
tcx.item_name(impl_item.def_id),
110+
struct_ty
111+
),
112+
vir::expr! {forall ..[trait_ty_decls] :: {[assoc_fun(trait_tys)]} ([assoc_fun(trait_tys)]) == (assoc_type_expr)}
113+
));
114+
});
115+
116+
Ok((
117+
vcx.mk_domain(
118+
vir_format_identifier!(
119+
vcx,
120+
"t_{}_{}",
121+
trait_data.trait_name,
122+
tcx.type_of(*task_key).instantiate_identity().to_string()
123+
),
124+
&[],
125+
vcx.alloc_slice(&axs),
126+
&[],
127+
None,
128+
),
129+
(),
130+
))
131+
})
132+
}
133+
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
use std::collections::HashMap;
2+
3+
use prusti_rustc_interface::{middle::ty::AssocKind, span::def_id::DefId};
4+
use task_encoder::{EncodeFullResult, TaskEncoder, TaskEncoderDependencies};
5+
use vir::{FunctionIdn, vir_format_identifier};
6+
7+
use crate::encoders::ty::generics::{GParams, GenericParamsEnc};
8+
9+
pub struct TraitEnc;
10+
11+
#[derive(Debug, Clone)]
12+
pub struct TraitData<'vir> {
13+
pub trait_name: &'vir str,
14+
pub type_did_fun_mapping: HashMap<DefId, FunctionIdn<'vir, vir::ManyTyVal, vir::TyVal>>,
15+
pub impl_fun: FunctionIdn<'vir, vir::ManyTyVal, vir::Bool>,
16+
}
17+
18+
impl TaskEncoder for TraitEnc {
19+
task_encoder::encoder_cache!(TraitEnc);
20+
21+
fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> Self::TaskKey<'vir> {
22+
*task
23+
}
24+
25+
type TaskDescription<'vir> = DefId;
26+
27+
type OutputFullDependency<'vir> = TraitData<'vir>;
28+
type OutputFullLocal<'vir> = vir::Domain<'vir>;
29+
30+
fn emit_outputs<'vir>(program: &mut task_encoder::Program<'vir>) {
31+
for dom in TraitEnc::all_outputs_local_no_errors() {
32+
program.add_domain(dom);
33+
}
34+
}
35+
36+
fn do_encode_full<'vir>(
37+
task_key: &Self::TaskKey<'vir>,
38+
deps: &mut TaskEncoderDependencies<'vir, Self>,
39+
) -> EncodeFullResult<'vir, Self> {
40+
deps.emit_output_ref(*task_key, ())?;
41+
vir::with_vcx(|vcx| {
42+
let tcx = vcx.tcx();
43+
let params = deps.require_dep::<GenericParamsEnc>(GParams::from(*task_key))?;
44+
let trait_name = vcx.alloc_str(tcx.item_name(task_key).as_str());
45+
let type_did_fun_mapping = tcx
46+
.associated_items(task_key)
47+
.in_definition_order()
48+
.filter(|item| matches!(item.kind, AssocKind::Type { data: _ }))
49+
.map(|item| {
50+
let params_type = deps
51+
.require_dep::<GenericParamsEnc>(GParams::from(item.def_id))
52+
.unwrap();
53+
(
54+
item.def_id,
55+
FunctionIdn::new(
56+
vir_format_identifier!(
57+
vcx,
58+
"{}_Assoc_{}_func",
59+
trait_name,
60+
tcx.item_name(item.def_id),
61+
),
62+
vcx.alloc_slice(&vec![vir::TYPE_TYVAL; params_type.ty_exprs().len()]), // params_type also includes parameters of trait itself
63+
vir::TYPE_TYVAL,
64+
),
65+
)
66+
})
67+
.collect::<HashMap<DefId, FunctionIdn<'vir, vir::ManyTyVal, vir::TyVal>>>();
68+
let mut funcs = type_did_fun_mapping
69+
.values()
70+
.map(|function_idn| vcx.mk_domain_function(*function_idn, false, None))
71+
.collect::<Vec<_>>();
72+
let impl_fun = FunctionIdn::new(
73+
vir_format_identifier!(vcx, "{}_impl", trait_name),
74+
vcx.alloc_slice(&(vec![vir::TYPE_TYVAL; params.ty_exprs().len()])),
75+
vir::TYPE_BOOL,
76+
);
77+
let impl_fun_data = vcx.mk_domain_function(impl_fun, false, None);
78+
funcs.push(impl_fun_data);
79+
let trait_domain = vcx.mk_domain(
80+
vir_format_identifier!(vcx, "t_{}", trait_name),
81+
&[],
82+
&[],
83+
vcx.alloc_slice(funcs.as_slice()),
84+
None,
85+
);
86+
Ok((
87+
trait_domain,
88+
TraitData {
89+
trait_name,
90+
type_did_fun_mapping,
91+
impl_fun,
92+
},
93+
))
94+
})
95+
}
96+
}

prusti-encoder/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use crate::encoders::{
2222
Impure, Pure,
2323
custom::PairUseEnc,
2424
ty::{
25-
generics::GArgsCastEnc,
25+
generics::{GArgsCastEnc, trait_impls::TraitImplEnc, traits::TraitEnc},
2626
interpretation::bitvec::BitVecEnc,
2727
lifted::{TyConstructorEnc, TypeOfEnc},
2828
},
@@ -104,6 +104,8 @@ pub fn test_entrypoint<'tcx>(
104104

105105
program.header("custom");
106106
PairUseEnc::emit_outputs(&mut program);
107+
TraitEnc::emit_outputs(&mut program);
108+
TraitImplEnc::emit_outputs(&mut program);
107109

108110
if std::env::var("LOCAL_TESTING").is_ok() {
109111
std::fs::write("local-testing/simple.vpr", program.code()).unwrap();
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
fn foo<Y: MyTrait>(x: Y::SomeType) {}
2+
3+
trait MyTrait {
4+
type SomeType;
5+
}
6+
7+
struct St1 {}
8+
struct St2 {}
9+
10+
impl MyTrait for St1 {
11+
type SomeType = u32;
12+
}
13+
14+
impl MyTrait for St2 {
15+
type SomeType = u64;
16+
}
17+
18+
fn bar() {
19+
foo::<St1>(5);
20+
}

0 commit comments

Comments
 (0)