Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 110 additions & 3 deletions compiler-core/checking/src/algorithm/constraint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ where
Ok(residual)
}

#[derive(Clone, PartialEq, Eq, Hash)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct ConstraintApplication {
pub(crate) file_id: FileId,
pub(crate) item_id: TypeItemId,
Expand Down Expand Up @@ -469,7 +469,7 @@ fn match_type(
let given_core = &state.storage[given];

match (wanted_core, given_core) {
(_, Type::Variable(Variable::Bound(level))) => {
(_, Type::Variable(Variable::Bound(level, _))) => {
if let Some(&bound) = bindings.get(level) {
match can_unify(state, wanted, bound) {
CanUnify::Equal => MatchType::Match,
Expand Down Expand Up @@ -552,6 +552,28 @@ fn match_given_type(state: &mut CheckState, wanted: TypeId, given: TypeId) -> Ma
match (wanted_core, given_core) {
(Type::Unification(_), _) => MatchType::Stuck,

(
Type::Variable(Variable::Bound(w_level, w_kind)),
Type::Variable(Variable::Bound(g_level, g_kind)),
) => {
if w_level == g_level {
match_given_type(state, *w_kind, *g_kind)
} else {
MatchType::Apart
}
}

(
Type::Variable(Variable::Skolem(w_level, w_kind)),
Type::Variable(Variable::Skolem(g_level, g_kind)),
) => {
if w_level == g_level {
match_given_type(state, *w_kind, *g_kind)
} else {
MatchType::Apart
}
}

(
&Type::Application(w_function, w_argument),
&Type::Application(g_function, g_argument),
Expand Down Expand Up @@ -668,6 +690,22 @@ fn can_unify(state: &mut CheckState, t1: TypeId, t2: TypeId) -> CanUnify {
}
}

(&Type::Variable(Variable::Bound(l1, k1)), &Type::Variable(Variable::Bound(l2, k2))) => {
if l1 == l2 {
can_unify(state, k1, k2)
} else {
Apart
}
}

(&Type::Variable(Variable::Skolem(l1, k1)), &Type::Variable(Variable::Skolem(l2, k2))) => {
if l1 == l2 {
can_unify(state, k1, k2)
} else {
Apart
}
}

_ => Apart,
}
}
Expand Down Expand Up @@ -718,6 +756,25 @@ where
}
}

let mut argument_levels = FxHashSet::default();
for &(argument, _) in &instance.arguments {
let localized = transfer::localize(state, context, argument);
CollectBoundLevels::on(state, localized, &mut argument_levels);
}

let mut constraint_variables = FxHashMap::default();
for &(constraint, _) in &instance.constraints {
let localized = transfer::localize(state, context, constraint);
CollectBoundVariables::on(state, localized, &mut constraint_variables);
}

for (level, kind) in constraint_variables {
if !argument_levels.contains(&level) && !bindings.contains_key(&level) {
let unification = state.fresh_unification_kinded(kind);
bindings.insert(level, unification);
}
}

let constraints = instance
.constraints
.iter()
Expand Down Expand Up @@ -880,11 +937,61 @@ impl<'a> ApplyBindings<'a> {
impl TypeFold for ApplyBindings<'_> {
fn transform(&mut self, _state: &mut CheckState, id: TypeId, t: &Type) -> FoldAction {
match t {
Type::Variable(Variable::Bound(level)) => {
Type::Variable(Variable::Bound(level, _)) => {
let id = self.bindings.get(level).copied().unwrap_or(id);
FoldAction::Replace(id)
}
_ => FoldAction::Continue,
}
}
}

/// Collects all bound variable levels from a type.
struct CollectBoundLevels<'a> {
levels: &'a mut FxHashSet<debruijn::Level>,
}

impl<'a> CollectBoundLevels<'a> {
fn on(state: &mut CheckState, type_id: TypeId, levels: &'a mut FxHashSet<debruijn::Level>) {
fold_type(state, type_id, &mut CollectBoundLevels { levels });
}
}

impl TypeFold for CollectBoundLevels<'_> {
fn transform(&mut self, _state: &mut CheckState, id: TypeId, t: &Type) -> FoldAction {
match t {
Type::Variable(Variable::Bound(level, _)) => {
self.levels.insert(*level);
FoldAction::Replace(id)
}
_ => FoldAction::Continue,
}
}
}

/// Collects all bound variables with their kinds from a type.
struct CollectBoundVariables<'a> {
variables: &'a mut FxHashMap<debruijn::Level, TypeId>,
}

impl<'a> CollectBoundVariables<'a> {
fn on(
state: &mut CheckState,
type_id: TypeId,
variables: &'a mut FxHashMap<debruijn::Level, TypeId>,
) {
fold_type(state, type_id, &mut CollectBoundVariables { variables });
}
}

impl TypeFold for CollectBoundVariables<'_> {
fn transform(&mut self, _state: &mut CheckState, id: TypeId, t: &Type) -> FoldAction {
match t {
Type::Variable(Variable::Bound(level, kind)) => {
self.variables.insert(*level, *kind);
FoldAction::Replace(id)
}
_ => FoldAction::Continue,
}
}
}
14 changes: 4 additions & 10 deletions compiler-core/checking/src/algorithm/derive/higher_kinded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,11 @@ where
return false;
};

let Some(kind) = lookup_variable_kind(state, variable) else {
return false;
let kind = match variable {
Variable::Skolem(_, kind) => *kind,
Variable::Bound(_, kind) => *kind,
Variable::Free(_) => context.prim.unknown,
};

Zonk::on(state, kind) == context.prim.type_to_type
}

fn lookup_variable_kind(state: &CheckState, variable: &Variable) -> Option<TypeId> {
match variable {
Variable::Skolem(_, kind) => Some(*kind),
Variable::Bound(level) => state.type_scope.kinds.get(*level).copied(),
Variable::Free(_) => None,
}
}
14 changes: 12 additions & 2 deletions compiler-core/checking/src/algorithm/fold.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use crate::algorithm::state::CheckState;
use crate::core::{ForallBinder, RowType, Type, TypeId};
use crate::core::{ForallBinder, RowType, Type, TypeId, Variable};

/// Controls behavior during type folding.
pub enum FoldAction {
Expand Down Expand Up @@ -103,7 +103,17 @@ pub fn fold_type<F: TypeFold>(state: &mut CheckState, id: TypeId, folder: &mut F
state.storage.intern(Type::SynonymApplication(saturation, file_id, type_id, arguments))
}
Type::Unification(_) => id,
Type::Variable(_) => id,
Type::Variable(variable) => match variable {
Variable::Bound(level, kind) => {
let kind = fold_type(state, kind, folder);
state.storage.intern(Type::Variable(Variable::Bound(level, kind)))
}
Variable::Skolem(level, kind) => {
let kind = fold_type(state, kind, folder);
state.storage.intern(Type::Variable(Variable::Skolem(level, kind)))
}
Variable::Free(_) => id,
},
Type::Unknown => id,
}
}
6 changes: 4 additions & 2 deletions compiler-core/checking/src/algorithm/inspect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ where

state.type_scope.kinds.insert(new_level, binder.kind);

let variable = state.storage.intern(Type::Variable(Variable::Bound(new_level)));
let variable =
state.storage.intern(Type::Variable(Variable::Bound(new_level, binder.kind)));
let inner = substitute::SubstituteBound::on(state, old_level, variable, inner);

variables.push(binder);
Expand Down Expand Up @@ -142,7 +143,8 @@ where
let level = state.type_scope.bound.bind(debruijn::Variable::Core);
state.type_scope.kinds.insert(level, binder_kind);

let variable = state.storage.intern(Type::Variable(Variable::Bound(level)));
let variable =
state.storage.intern(Type::Variable(Variable::Bound(level, binder_kind)));
current_id = substitute::SubstituteBound::on(state, binder_level, variable, inner);
}

Expand Down
30 changes: 14 additions & 16 deletions compiler-core/checking/src/algorithm/kind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,14 +258,14 @@ fn infer_forall_variable(
) -> (TypeId, TypeId) {
let level =
state.type_scope.lookup_forall(forall).expect("invariant violated: TypeScope::bind_forall");
let variable = Variable::Bound(level);

let t = state.storage.intern(Type::Variable(variable));
let k = state
.type_scope
.lookup_forall_kind(forall)
.expect("invariant violated: TypeScope::bind_forall");

let variable = Variable::Bound(level, k);
let t = state.storage.intern(Type::Variable(variable));

(t, k)
}

Expand All @@ -274,26 +274,26 @@ fn infer_implicit_variable<Q: ExternalQueries>(
context: &CheckContext<Q>,
implicit: &lowering::ImplicitTypeVariable,
) -> (TypeId, TypeId) {
let t = if implicit.binding {
let (t, k) = if implicit.binding {
let kind = state.fresh_unification(context);

let level = state.type_scope.bind_implicit(implicit.node, implicit.id, kind);
let variable = Variable::Bound(level);
let variable = Variable::Bound(level, kind);

state.storage.intern(Type::Variable(variable))
(state.storage.intern(Type::Variable(variable)), kind)
} else {
let level = state
.type_scope
.lookup_implicit(implicit.node, implicit.id)
.expect("invariant violated: TypeScope::bind_implicit");
let variable = Variable::Bound(level);
state.storage.intern(Type::Variable(variable))
};
let kind = state
.type_scope
.lookup_implicit_kind(implicit.node, implicit.id)
.expect("invariant violated: TypeScope::bind_implicit");

let k = state
.type_scope
.lookup_implicit_kind(implicit.node, implicit.id)
.expect("invariant violated: TypeScope::bind_implicit");
let variable = Variable::Bound(level, kind);
(state.storage.intern(Type::Variable(variable)), kind)
};

(t, k)
}
Expand Down Expand Up @@ -508,9 +508,7 @@ where
Type::Unification(unification_id) => state.unification.get(unification_id).kind,

Type::Variable(ref variable) => match variable {
Variable::Bound(level) => {
state.type_scope.kinds.get(*level).copied().unwrap_or(unknown)
}
Variable::Bound(_, kind) => *kind,
Variable::Skolem(_, kind) => *kind,
Variable::Free(_) => unknown,
},
Expand Down
25 changes: 15 additions & 10 deletions compiler-core/checking/src/algorithm/quantify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::algorithm::constraint::{self, ConstraintApplication};
use crate::algorithm::fold::Zonk;
use crate::algorithm::state::{CheckContext, CheckState};
use crate::algorithm::substitute::{ShiftBound, SubstituteUnification, UniToLevel};
use crate::core::{Class, ForallBinder, Instance, RowType, Type, TypeId, debruijn};
use crate::core::{Class, ForallBinder, Instance, RowType, Type, TypeId, Variable, debruijn};

pub fn quantify(state: &mut CheckState, id: TypeId) -> Option<(TypeId, debruijn::Size)> {
let graph = collect_unification(state, id);
Expand Down Expand Up @@ -47,7 +47,7 @@ pub fn quantify(state: &mut CheckState, id: TypeId) -> Option<(TypeId, debruijn:
let binder = ForallBinder { visible: false, name, level, kind };
quantified = state.storage.intern(Type::Forall(binder, quantified));

substitutions.insert(id, level);
substitutions.insert(id, (level, kind));
}

let quantified = SubstituteUnification::on(&substitutions, state, quantified);
Expand Down Expand Up @@ -219,9 +219,11 @@ pub fn quantify_class(state: &mut CheckState, class: &mut Class) -> Option<debru

let mut substitutions = UniToLevel::default();
for (index, &id) in unsolved.iter().rev().enumerate() {
let kind = state.unification.get(id).kind;
let kind = ShiftBound::on(state, kind, size.0);
let index = debruijn::Index(index as u32);
let level = index.to_level(size)?;
substitutions.insert(id, level);
substitutions.insert(id, (level, kind));
}

let type_variable_kinds = class.type_variable_kinds.iter().map(|&kind| {
Expand Down Expand Up @@ -273,16 +275,14 @@ pub fn quantify_instance(state: &mut CheckState, instance: &mut Instance) -> Opt

let mut substitutions = UniToLevel::default();
for (index, &id) in unsolved.iter().rev().enumerate() {
let kind = state.unification.get(id).kind;
let kind = ShiftBound::on(state, kind, size.0);
let index = debruijn::Index(index as u32);
let level = index.to_level(size)?;
substitutions.insert(id, level);
substitutions.insert(id, (level, kind));
}

let kind_variables = substitutions.iter().map(|(&id, &level)| {
let kind = state.unification.get(id).kind;
(level, kind)
});

let kind_variables = substitutions.values().copied();
let kind_variables = kind_variables.sorted_by_key(|(level, _)| *level);
let kind_variables = kind_variables.map(|(_, kind)| kind).collect_vec();

Expand Down Expand Up @@ -418,7 +418,12 @@ pub fn collect_unification_into(graph: &mut UniGraph, state: &mut CheckState, id
let entry = state.unification.get(unification_id);
aux(graph, state, entry.kind, Some(unification_id));
}
Type::Variable(_) => (),
Type::Variable(ref variable) => match variable {
Variable::Bound(_, kind) | Variable::Skolem(_, kind) => {
aux(graph, state, *kind, dependent);
}
Variable::Free(_) => {}
},
Type::Unknown => (),
}
}
Expand Down
Loading