Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions compiler/noirc_frontend/src/elaborator/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ impl Elaborator<'_> {
) {
// Set up trait impl state
self.current_trait_impl = trait_impl.impl_id;
self.current_trait = trait_impl.trait_id;
self.self_type = trait_impl.methods.self_type.clone();
self.generics = generics;

Expand All @@ -137,6 +138,7 @@ impl Elaborator<'_> {
// Cleanup
self.self_type = None;
self.current_trait_impl = None;
self.current_trait = None;
self.generics.clear();
}

Expand Down Expand Up @@ -176,6 +178,13 @@ impl Elaborator<'_> {

let mut trait_constraints =
self.resolve_trait_constraints_and_add_to_scope(&func.def.where_clause);

// Add constraints for parent traits that have associated types.
let (parent_generics, parent_constraints) =
self.add_parent_associated_type_constraints(&trait_constraints);
generics.extend(parent_generics);
trait_constraints.extend(parent_constraints);

let mut extra_trait_constraints =
vecmap(extra_trait_constraints, |(constraint, _)| constraint.clone());
extra_trait_constraints.extend(associated_generics_trait_constraints);
Expand Down Expand Up @@ -486,6 +495,7 @@ impl Elaborator<'_> {
self.local_module = Some(func_meta.source_module);
self.self_type = func_meta.self_type.clone();
self.current_trait_impl = func_meta.trait_impl;
self.current_trait = func_meta.trait_id;

self.scopes.start_function();
let old_item = self.current_item.replace(DependencyId::Function(id));
Expand Down
6 changes: 5 additions & 1 deletion compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ pub struct Elaborator<'context> {
/// to the corresponding trait impl ID.
current_trait_impl: Option<TraitImplId>,

/// The trait we're currently resolving, if we are resolving one.
/// The trait we're currently resolving or implementing, if any.
/// Set during both trait definitions (`trait Foo { ... }`) and
/// trait impl elaboration (`impl Foo for Bar { ... }`).
current_trait: Option<TraitId>,

/// In-resolution names
Expand Down Expand Up @@ -661,6 +663,7 @@ impl<'context> Elaborator<'context> {

self.generics = trait_impl.resolved_generics.clone();
self.current_trait_impl = trait_impl.impl_id;
self.current_trait = trait_impl.trait_id;

self.add_trait_impl_assumed_trait_implementations(trait_impl.impl_id);
self.check_trait_impl_where_clause_matches_trait_where_clause(&trait_impl);
Expand All @@ -681,6 +684,7 @@ impl<'context> Elaborator<'context> {

self.self_type = None;
self.current_trait_impl = None;
self.current_trait = None;
self.generics.clear();
}

Expand Down
3 changes: 3 additions & 0 deletions compiler/noirc_frontend/src/elaborator/trait_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ impl Elaborator<'_> {
let previous_local_module = self.local_module.replace(trait_impl.module_id);
let previous_current_trait_impl =
std::mem::replace(&mut self.current_trait_impl, trait_impl.impl_id);
let previous_current_trait =
std::mem::replace(&mut self.current_trait, trait_impl.trait_id);

let self_type = trait_impl.methods.self_type.clone();
let self_type =
Expand Down Expand Up @@ -270,6 +272,7 @@ impl Elaborator<'_> {

self.local_module = previous_local_module;
self.current_trait_impl = previous_current_trait_impl;
self.current_trait = previous_current_trait;
self.self_type = previous_self_type;
}

Expand Down
126 changes: 125 additions & 1 deletion compiler/noirc_frontend/src/elaborator/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ use crate::{
},
hir_def::{
function::FuncMeta,
traits::{ResolvedTraitBound, TraitConstraint, TraitFunction},
traits::{NamedType, ResolvedTraitBound, TraitConstraint, TraitFunction},
},
node_interner::{
DependencyId, FuncId, ImplSearchErrorKind, NodeInterner, ReferenceId, TraitId,
Expand Down Expand Up @@ -589,6 +589,130 @@ impl Elaborator<'_> {
Some(constraint)
}

/// For each resolved trait constraint, add constraints for parent traits that have
/// associated types. This creates fresh type variables for the parent associated types
/// so that `M::Key` syntax can be resolved via `self.trait_bounds`.
///
/// The parent trait bounds are obtained from `Trait.trait_bounds` (already resolved
/// during `collect_traits` with associated type variables) and instantiated via
/// `instantiate_parent_trait_bound` to substitute the child trait's bindings. The
/// named (associated) types are then replaced with fresh per-function type variables
/// so they can be wrapped in `Type::Forall` and freshened at each call site.
///
/// Returns (new_generics, new_constraints) to be added to the function's generics
/// and trait constraints respectively.
pub(super) fn add_parent_associated_type_constraints(
&mut self,
constraints: &[TraitConstraint],
) -> (Vec<TypeVariable>, Vec<TraitConstraint>) {
let mut new_generics = Vec::new();
let mut new_constraints = Vec::new();
let mut visited = rustc_hash::FxHashSet::default();

for constraint in constraints {
self.collect_parent_associated_types(
&constraint.typ,
&constraint.trait_bound,
&mut new_generics,
&mut new_constraints,
&mut visited,
);
visited.clear();
}

(new_generics, new_constraints)
}

/// Recursively walk parent trait hierarchies and create fresh type variables
/// for any associated types found on parent traits. The new constraints are
/// pushed to `self.trait_bounds` and returned via the output parameters.
fn collect_parent_associated_types(
&mut self,
object_type: &Type,
trait_bound: &ResolvedTraitBound,
new_generics: &mut Vec<TypeVariable>,
new_constraints: &mut Vec<TraitConstraint>,
visited: &mut rustc_hash::FxHashSet<TraitId>,
) {
let trait_id = trait_bound.trait_id;
if !visited.insert(trait_id) {
return;
}

let parent_bounds: Vec<_> = self
.interner
.try_get_trait(trait_id)
.map(|t| t.trait_bounds.clone())
.unwrap_or_default();

for parent_bound in &parent_bounds {
// Substitute the child trait's bindings into the parent bound.
let instantiated = self.instantiate_parent_trait_bound(trait_bound, parent_bound);

// Skip if there are no associated types on this parent trait,
// or if we already have a constraint for this type + parent trait.
let has_named = !instantiated.trait_generics.named.is_empty();
let already_has = self
.trait_bounds
.iter()
.any(|c| c.trait_bound.trait_id == instantiated.trait_id && c.typ == *object_type);

if has_named && !already_has {
// Replace the named (associated) type variables with fresh per-function
// ones so they can be included in Type::Forall and freshened at call sites.
let parent_trait = self.interner.get_trait(instantiated.trait_id);
let parent_trait_name = parent_trait.name.to_string();
let object_name = object_type.to_string();

let named = vecmap(&instantiated.trait_generics.named, |named_type| {
let fresh_id = self.interner.next_type_variable_id();
let kind = named_type.typ.kind();
let type_var = TypeVariable::unbound(fresh_id, kind);

let assoc_type_id = parent_trait
.associated_types
.iter()
.find(|a| a.name.as_ref() == named_type.name.as_str())
.map_or(fresh_id, |a| a.type_var.id());

let fresh_type = type_var.clone().into_implicit_named_generic(
&Rc::new(named_type.name.to_string()),
Some((object_name.as_str(), parent_trait_name.as_str())),
assoc_type_id,
);

new_generics.push(type_var);
NamedType {
name: Ident::new(named_type.name.to_string(), instantiated.location),
typ: fresh_type,
}
});

let trait_generics =
TraitGenerics { ordered: instantiated.trait_generics.ordered.clone(), named };
let parent_constraint = TraitConstraint {
typ: object_type.clone(),
trait_bound: ResolvedTraitBound {
trait_id: instantiated.trait_id,
trait_generics,
location: instantiated.location,
},
};
self.trait_bounds.push(parent_constraint.clone());
new_constraints.push(parent_constraint);
}

// Recurse for grandparent traits
self.collect_parent_associated_types(
object_type,
&instantiated,
new_generics,
new_constraints,
visited,
);
}
}

/// Adds an assumed trait implementation for the given object type and trait bound.
///
/// This also recursively adds assumed implementations for any parent traits.
Expand Down
Loading
Loading