Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
9a6e5f1
Add support for type aliases with associated types
ThomasMayerl Nov 21, 2025
2b77629
Run clippy
ThomasMayerl Nov 21, 2025
cfc09fd
Solve clippy issue
ThomasMayerl Nov 21, 2025
f1aee2a
Fix issue with associated types for primitive implementations
ThomasMayerl Nov 24, 2025
f5c131e
Move generation of trait/impl into its own encoder
ThomasMayerl Nov 25, 2025
24c66d7
Run clippy --fix
ThomasMayerl Nov 25, 2025
5ed73d9
Fix issue that trait/impl encodings use invalid identifiers
ThomasMayerl Nov 25, 2025
0fd9e87
Run clippy --fix
ThomasMayerl Nov 25, 2025
7a743e6
Add impl function for traits
ThomasMayerl Nov 25, 2025
cfaeef0
fix issues
ThomasMayerl Nov 27, 2025
fd627e4
Support generic traits
ThomasMayerl Nov 28, 2025
1d36e9d
Run clippy
ThomasMayerl Nov 28, 2025
1f10be8
Make trait and impl support different number of parameters
ThomasMayerl Nov 28, 2025
eb9ff51
Adapt test file
ThomasMayerl Nov 28, 2025
352b4ff
Adapt test file
ThomasMayerl Nov 28, 2025
d588579
Fix cycle issue for simple test cases
ThomasMayerl Dec 2, 2025
1f5b71a
Allow associated types to be generic
ThomasMayerl Dec 5, 2025
970de05
Run clippy
Dec 5, 2025
0ec6d6f
Generate domains for all implementations of traits (even for primitiv…
ThomasMayerl Dec 5, 2025
060201a
Run fmt on tests
ThomasMayerl Dec 8, 2025
81885ab
Implement PR feedback
ThomasMayerl Dec 8, 2025
0b21b46
use iterator chain instead of vec
ThomasMayerl Dec 8, 2025
d71b2a2
Implement feedback from PR
ThomasMayerl Dec 8, 2025
3e1b4b4
Use HashMap instead of slice of pairs
ThomasMayerl Dec 8, 2025
681da9f
Implement feedback from PR
ThomasMayerl Dec 12, 2025
20948c9
Fix indentation
ThomasMayerl Dec 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion prusti-encoder/src/encoders/mir_fn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub use function::*;
pub use method::*;
pub use signature::*;

use crate::encoders::ty::generics::{GArgs, GParams};
use crate::encoders::ty::generics::{GArgs, GParams, trait_impls::TraitImplEnc};

use prusti_interface::specs::specifications::SpecQuery;
use prusti_rustc_interface::{hir, middle::ty, span::def_id::DefId};
Expand Down Expand Up @@ -61,4 +61,13 @@ pub fn encode_all_in_crate<'tcx>(tcx: ty::TyCtxt<'tcx>) {
}
}
}

// This creates the impl encoding for all traits in the crate
// To iterate over all _visible_ impl blocks,
// use tcx.visible_traits and tcx.all_impls(trait_id)
for def_id in tcx.hir_crate_items(()).definitions() {
if let hir::def::DefKind::Impl { of_trait: true } = tcx.def_kind(def_id) {
TraitImplEnc::encode(def_id.to_def_id(), false).unwrap();
}
}
}
12 changes: 8 additions & 4 deletions prusti-encoder/src/encoders/ty/generics/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ pub struct GArgs<'tcx> {
pub(super) args: &'tcx [ty::GenericArg<'tcx>],
}

pub enum GParamVariant<'tcx> {
Param(ty::ParamTy),
Alias(ty::AliasTy<'tcx>),
}

impl<'tcx> GArgs<'tcx> {
pub fn new(context: impl Into<GParams<'tcx>>, args: &'tcx [ty::GenericArg<'tcx>]) -> Self {
GArgs {
Expand All @@ -34,12 +39,11 @@ impl<'tcx> GArgs<'tcx> {
self.context.normalize(ty)
}

pub fn expect_param(self) -> ty::ParamTy {
pub fn expect_param(self) -> GParamVariant<'tcx> {
assert_eq!(self.args.len(), 1);
match self.args[0].expect_ty().kind() {
ty::TyKind::Param(p) => *p,
// TODO: this needs to be changed to support type aliases
ty::TyKind::Alias(..) => panic!("type aliases are not currently supported"),
ty::TyKind::Param(p) => GParamVariant::Param(*p),
ty::TyKind::Alias(_k, t) => GParamVariant::Alias(*t),
other => panic!("expected type parameter, {other:?}"),
}
}
Expand Down
2 changes: 2 additions & 0 deletions prusti-encoder/src/encoders/ty/generics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ mod params;
mod casters;
mod args_ty;
mod args;
pub mod traits;
pub mod trait_impls;

pub use args::*;
pub use args_ty::*;
Expand Down
31 changes: 27 additions & 4 deletions prusti-encoder/src/encoders/ty/generics/params.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
use prusti_interface::specs::typed::ExternSpecKind;
use prusti_rustc_interface::{
middle::ty,
middle::{ty, ty::TyKind},
span::{def_id::DefId, symbol},
};
use task_encoder::{EncodeFullResult, TaskEncoder, TaskEncoderDependencies};
use vir::{CastType, HasType};

use crate::encoders::{
TyUsePureEnc,
ty::{RustTyDecomposition, data::TySpecifics, generics::GArgsTyEnc, lifted::TyConstructorEnc},
ty::{
RustTyDecomposition,
data::TySpecifics,
generics::{GArgsTyEnc, GParamVariant, traits::TraitEnc},
lifted::TyConstructorEnc,
},
};

/// The list of defined parameters in a given context. E.g. the type parameters
Expand Down Expand Up @@ -176,7 +181,6 @@ impl<'vir> From<DefId> for GParams<'vir> {
/// `fn foo<T, U>(x: U)` into the Viper `method foo(x: Ref, T: Type, U: Type)`
/// (handles the type parameters).
pub struct GenericParamsEnc;

#[derive(Debug, Clone)]
pub struct GenericParams<'vir> {
ty_args: &'vir [vir::TypeTyVal<'vir>],
Expand Down Expand Up @@ -250,7 +254,26 @@ impl<'vir> GenericParams<'vir> {
) -> vir::ExprTyVal<'vir> {
if let TySpecifics::Param(()) = &ty.ty.specifics {
let param = ty.args.expect_param();
return self.ty_exprs[self.map_idx(param.index).unwrap()];
return match param {
GParamVariant::Param(p) => self.ty_exprs[self.map_idx(p.index).unwrap()],
GParamVariant::Alias(a) => vir::with_vcx(|vcx| {
let tcx = vcx.tcx();
let trait_did = tcx.associated_item(a.def_id).container_id(tcx);
let trait_data = deps.require_dep::<TraitEnc>(trait_did).unwrap();
let tys = &a
.args
.iter()
.map(|arg| match arg.expect_ty().kind() {
TyKind::Param(p) => self.ty_exprs[self.map_idx(p.index).unwrap()],
_ => self.ty_expr(
deps,
RustTyDecomposition::from_ty(arg.expect_ty(), tcx, ty.args.context),
),
})
.collect::<Vec<_>>();
(trait_data.type_did_fun_mapping.get(&a.def_id).unwrap())(tys)
}),
};
}
let ty_constructor = deps
.require_ref::<TyConstructorEnc>(ty.ty)
Expand Down
133 changes: 133 additions & 0 deletions prusti-encoder/src/encoders/ty/generics/trait_impls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
use std::iter;

use prusti_rustc_interface::{middle::ty::AssocKind, span::def_id::DefId};
use task_encoder::{EncodeFullResult, TaskEncoder, TaskEncoderDependencies};
use vir::{CastType, Domain, vir_format_identifier};

use crate::encoders::ty::{
RustTyDecomposition,
generics::{GArgs, GArgsTyEnc, GParams, GenericParamsEnc, traits::TraitEnc},
};

pub struct TraitImplEnc;

impl TaskEncoder for TraitImplEnc {
task_encoder::encoder_cache!(TraitImplEnc);

fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> Self::TaskKey<'vir> {
*task
}

fn emit_outputs<'vir>(program: &mut task_encoder::Program<'vir>) {
for dom in TraitImplEnc::all_outputs_local_no_errors() {
program.add_domain(dom);
}
}

type TaskDescription<'vir> = DefId;
type OutputFullLocal<'vir> = Domain<'vir>;

fn do_encode_full<'vir>(
task_key: &Self::TaskKey<'vir>,
deps: &mut TaskEncoderDependencies<'vir, Self>,
) -> EncodeFullResult<'vir, Self> {
deps.emit_output_ref(*task_key, ())?;

vir::with_vcx(|vcx| {
let tcx = vcx.tcx();

let ctx = GParams::from(*task_key);

let params = deps.require_dep::<GenericParamsEnc>(ctx)?;

let trait_ref = tcx.impl_trait_ref(task_key).unwrap().instantiate_identity();
let trait_did = trait_ref.def_id;
let trait_data = deps.require_dep::<TraitEnc>(trait_did)?;

let args = deps.require_dep::<GArgsTyEnc>(GArgs::new(ctx, trait_ref.args))?;

let mut axs = Vec::new();

let struct_ty = tcx.type_of(task_key).instantiate_identity();

let impl_fun = trait_data.impl_fun;
let trait_ty_decls = params
.ty_decls()
.iter()
.map(|dec| dec.upcast_ty())
.collect::<Vec<_>>();
let trait_tys = args.get_ty();

axs.push(
vcx.mk_domain_axiom(
vir_format_identifier!(vcx, "{}_impl_{}", trait_data.trait_name, struct_ty),
vir::expr! {forall ..[trait_ty_decls] :: {[impl_fun(trait_tys)]} [impl_fun(trait_tys)]}
)
);

tcx.associated_items(*task_key)
.in_definition_order()
.filter(|item| matches!(item.kind, AssocKind::Type { data: _ }))
.for_each(|impl_item| {
let assoc_fun = trait_data.type_did_fun_mapping.get(&impl_item.trait_item_def_id.unwrap()).unwrap();
// construct arguments for assoc_item function
// parameters of the trait are substituted
// by the arguments used in the impl
// parameters of the associated type are kept

// parameters of assoc item include already substituted arguments
let assoc_params = deps
.require_dep::<GenericParamsEnc>(GParams::from(impl_item.def_id))
.unwrap();

// the type we want to resolve the type alias to
let assoc_type_expr = assoc_params.ty_expr(
deps,
RustTyDecomposition::from_ty(
tcx.type_of(impl_item.def_id).instantiate_identity(),
tcx,
GParams::from(impl_item.def_id),
),
);
let assoc_decls = assoc_params
.ty_decls()
.iter()
.map(|dec| dec.upcast_ty())
.collect::<Vec<_>>();

// Combine substituted trait ty decls with the decls of the associated type
let mut trait_ty_decls = trait_ty_decls.clone();
trait_ty_decls.extend_from_slice(&assoc_decls[params.ty_exprs().len()..]);

// Combine substituted trait params with the params of the associated type
let trait_tys = vcx.alloc_slice(&iter::empty().chain(args.get_ty().to_owned()).chain(assoc_params.ty_exprs()[params.ty_exprs().len()..].to_owned()).collect::<Vec<_>>());
axs.push(vcx.mk_domain_axiom(
vir_format_identifier!(
vcx,
"{}_Assoc_{}_{}",
trait_data.trait_name,
tcx.item_name(impl_item.def_id),
struct_ty
),
vir::expr! {forall ..[trait_ty_decls] :: {[assoc_fun(trait_tys)]} ([assoc_fun(trait_tys)]) == (assoc_type_expr)}
));
});

Ok((
vcx.mk_domain(
vir_format_identifier!(
vcx,
"t_{}_{}",
trait_data.trait_name,
tcx.type_of(*task_key).instantiate_identity().to_string()
),
&[],
vcx.alloc_slice(&axs),
&[],
None,
),
(),
))
})
}
}
96 changes: 96 additions & 0 deletions prusti-encoder/src/encoders/ty/generics/traits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use std::collections::HashMap;

use prusti_rustc_interface::{middle::ty::AssocKind, span::def_id::DefId};
use task_encoder::{EncodeFullResult, TaskEncoder, TaskEncoderDependencies};
use vir::{FunctionIdn, vir_format_identifier};

use crate::encoders::ty::generics::{GParams, GenericParamsEnc};

pub struct TraitEnc;

#[derive(Debug, Clone)]
pub struct TraitData<'vir> {
pub trait_name: &'vir str,
pub type_did_fun_mapping: HashMap<DefId, FunctionIdn<'vir, vir::ManyTyVal, vir::TyVal>>,
pub impl_fun: FunctionIdn<'vir, vir::ManyTyVal, vir::Bool>,
}

impl TaskEncoder for TraitEnc {
task_encoder::encoder_cache!(TraitEnc);

fn task_to_key<'vir>(task: &Self::TaskDescription<'vir>) -> Self::TaskKey<'vir> {
*task
}

type TaskDescription<'vir> = DefId;

type OutputFullDependency<'vir> = TraitData<'vir>;
type OutputFullLocal<'vir> = vir::Domain<'vir>;

fn emit_outputs<'vir>(program: &mut task_encoder::Program<'vir>) {
for dom in TraitEnc::all_outputs_local_no_errors() {
program.add_domain(dom);
}
}

fn do_encode_full<'vir>(
task_key: &Self::TaskKey<'vir>,
deps: &mut TaskEncoderDependencies<'vir, Self>,
) -> EncodeFullResult<'vir, Self> {
deps.emit_output_ref(*task_key, ())?;
vir::with_vcx(|vcx| {
let tcx = vcx.tcx();
let params = deps.require_dep::<GenericParamsEnc>(GParams::from(*task_key))?;
let trait_name = vcx.alloc_str(tcx.item_name(task_key).as_str());
let type_did_fun_mapping = tcx
.associated_items(task_key)
.in_definition_order()
.filter(|item| matches!(item.kind, AssocKind::Type { data: _ }))
.map(|item| {
let params_type = deps
.require_dep::<GenericParamsEnc>(GParams::from(item.def_id))
.unwrap();
(
item.def_id,
FunctionIdn::new(
vir_format_identifier!(
vcx,
"{}_Assoc_{}_func",
trait_name,
tcx.item_name(item.def_id),
),
vcx.alloc_slice(&vec![vir::TYPE_TYVAL; params_type.ty_exprs().len()]), // params_type also includes parameters of trait itself
vir::TYPE_TYVAL,
),
)
})
.collect::<HashMap<DefId, FunctionIdn<'vir, vir::ManyTyVal, vir::TyVal>>>();
let mut funcs = type_did_fun_mapping
.values()
.map(|function_idn| vcx.mk_domain_function(*function_idn, false, None))
.collect::<Vec<_>>();
let impl_fun = FunctionIdn::new(
vir_format_identifier!(vcx, "{}_impl", trait_name),
vcx.alloc_slice(&(vec![vir::TYPE_TYVAL; params.ty_exprs().len()])),
vir::TYPE_BOOL,
);
let impl_fun_data = vcx.mk_domain_function(impl_fun, false, None);
funcs.push(impl_fun_data);
let trait_domain = vcx.mk_domain(
vir_format_identifier!(vcx, "t_{}", trait_name),
&[],
&[],
vcx.alloc_slice(funcs.as_slice()),
None,
);
Ok((
trait_domain,
TraitData {
trait_name,
type_did_fun_mapping,
impl_fun,
},
))
})
}
}
4 changes: 3 additions & 1 deletion prusti-encoder/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::encoders::{
Impure, Pure,
custom::PairUseEnc,
ty::{
generics::GArgsCastEnc,
generics::{GArgsCastEnc, trait_impls::TraitImplEnc, traits::TraitEnc},
interpretation::bitvec::BitVecEnc,
lifted::{TyConstructorEnc, TypeOfEnc},
},
Expand Down Expand Up @@ -103,6 +103,8 @@ pub fn test_entrypoint<'tcx>(

program.header("custom");
PairUseEnc::emit_outputs(&mut program);
TraitEnc::emit_outputs(&mut program);
TraitImplEnc::emit_outputs(&mut program);

if std::env::var("LOCAL_TESTING").is_ok() {
std::fs::write("local-testing/simple.vpr", program.code()).unwrap();
Expand Down
20 changes: 20 additions & 0 deletions prusti-tests/tests/verify/pass/traits/assoc_type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
fn foo<Y: MyTrait>(x: Y::SomeType) {}

trait MyTrait {
type SomeType;
}

struct St1 {}
struct St2 {}

impl MyTrait for St1 {
type SomeType = u32;
}

impl MyTrait for St2 {
type SomeType = u64;
}

fn bar() {
foo::<St1>(5);
}
Loading
Loading