Skip to content

Commit 547286b

Browse files
committed
Improve constraint generalisation algorithm
1 parent 320dd2f commit 547286b

File tree

10 files changed

+200
-13
lines changed

10 files changed

+200
-13
lines changed

compiler-core/checking/src/algorithm/quantify.rs

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ use petgraph::visit::{DfsPostOrder, Reversed};
99
use rustc_hash::FxHashSet;
1010
use smol_str::SmolStrBuilder;
1111

12-
use crate::ExternalQueries;
1312
use crate::algorithm::constraint::{self, ConstraintApplication};
1413
use crate::algorithm::fold::Zonk;
1514
use crate::algorithm::state::{CheckContext, CheckState};
1615
use crate::algorithm::substitute::{ShiftBound, SubstituteUnification, UniToLevel};
1716
use crate::core::{Class, ForallBinder, Instance, RowType, Type, TypeId, Variable, debruijn};
17+
use crate::{ExternalQueries, safe_loop};
1818

1919
pub fn quantify(state: &mut CheckState, id: TypeId) -> Option<(TypeId, debruijn::Size)> {
2020
let graph = collect_unification(state, id);
@@ -96,28 +96,25 @@ where
9696
return Ok(Some(quantified_with_constraints));
9797
}
9898

99-
let unsolved_graph = collect_unification(state, type_id);
100-
let unsolved_nodes: FxHashSet<u32> = unsolved_graph.nodes().collect();
101-
102-
let mut valid: FxHashSet<TypeId> = FxHashSet::default();
103-
let mut ambiguous = vec![];
99+
let mut pending = vec![];
104100
let mut unsatisfied = vec![];
105101

106102
for constraint in constraints {
107103
let constraint = Zonk::on(state, constraint);
108-
let unsolved_graph = collect_unification(state, constraint);
109-
if unsolved_graph.node_count() == 0 {
104+
let unification: FxHashSet<u32> = collect_unification(state, constraint).nodes().collect();
105+
if unification.is_empty() {
110106
unsatisfied.push(constraint);
111-
} else if unsolved_graph.nodes().all(|unification| unsolved_nodes.contains(&unification)) {
112-
valid.insert(constraint);
113107
} else {
114-
ambiguous.push(constraint);
108+
pending.push((constraint, unification));
115109
}
116110
}
117111

112+
let in_signature = collect_unification(state, type_id).nodes().collect();
113+
let (generalised, ambiguous) = classify_constraints_by_reachability(pending, in_signature);
114+
118115
// Subtle: stable ordering for consistent output
119-
let valid = valid.into_iter().sorted().collect_vec();
120-
let minimized = minimize_by_superclasses(state, context, valid)?;
116+
let generalised = generalised.into_iter().sorted().collect_vec();
117+
let minimized = minimize_by_superclasses(state, context, generalised)?;
121118

122119
let constrained_type = minimized.into_iter().rfold(type_id, |constrained, constraint| {
123120
state.storage.intern(Type::Constrained(constraint, constrained))
@@ -186,6 +183,40 @@ where
186183
.collect())
187184
}
188185

186+
/// Classifies constraints as valid or ambiguous based on variable reachability.
187+
///
188+
/// A constraint is valid if its variables are transitively reachable from the
189+
/// signature variables. The algorithm uses fixed-point iteration. As long as
190+
/// a constraint shares any variable with the reachable set, all its variables
191+
/// become reachable.
192+
fn classify_constraints_by_reachability(
193+
pending: Vec<(TypeId, FxHashSet<u32>)>,
194+
in_signature: FxHashSet<u32>,
195+
) -> (FxHashSet<TypeId>, Vec<TypeId>) {
196+
let mut reachable = in_signature;
197+
let mut valid = FxHashSet::default();
198+
let mut remaining = pending;
199+
200+
safe_loop! {
201+
let (connected, disconnected): (Vec<_>, Vec<_>) =
202+
remaining.into_iter().partition(|(_, unification)| {
203+
unification.iter().any(|variable| reachable.contains(variable))
204+
});
205+
206+
if connected.is_empty() {
207+
let ambiguous = disconnected.into_iter().map(|(id, _)| id).collect();
208+
return (valid, ambiguous);
209+
}
210+
211+
for (constraint, unification) in connected {
212+
valid.insert(constraint);
213+
reachable.extend(unification);
214+
}
215+
216+
remaining = disconnected;
217+
}
218+
}
219+
189220
fn generate_type_name(id: u32) -> smol_str::SmolStr {
190221
let mut builder = SmolStrBuilder::default();
191222
write!(builder, "t{id}").unwrap();
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
module Main where
2+
3+
import Prim.Row as Prim.Row
4+
5+
foreign import unsafeCoerce :: forall a b. a -> b
6+
7+
merge
8+
:: forall r1 r2 r3 r4
9+
. Prim.Row.Union r1 r2 r3
10+
=> Prim.Row.Nub r3 r4
11+
=> Record r1
12+
-> Record r2
13+
-> Record r4
14+
merge _ _ = unsafeCoerce {}
15+
16+
-- This should generalise the unsolved Union and Nub constraints
17+
-- rather than emitting AmbiguousConstraint errors
18+
a = merge { a: 123 }
19+
20+
-- Fully applied, should resolve to a concrete record type
21+
b = a { b: 123 }

tests-integration/fixtures/checking/210_row_constraint_generalization/Main.snap

Lines changed: 19 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module Main where
2+
3+
import Prim.Row as Prim.Row
4+
5+
foreign import unsafeCoerce :: forall a b. a -> b
6+
7+
fromUnion
8+
:: forall r1 r2 r3
9+
. Prim.Row.Union r1 r2 r3
10+
=> Record r3
11+
-> Int
12+
fromUnion _ = unsafeCoerce 0
13+
14+
test = fromUnion

tests-integration/fixtures/checking/211_row_constraint_result/Main.snap

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
module Main where
2+
3+
import Prim.Row as Prim.Row
4+
5+
foreign import unsafeCoerce :: forall a b. a -> b
6+
7+
chainedUnion
8+
:: forall r1 r2 r3 r4 r5
9+
. Prim.Row.Union r1 r2 r3
10+
=> Prim.Row.Union r3 r4 r5
11+
=> Record r1
12+
-> Record r5
13+
chainedUnion _ = unsafeCoerce {}
14+
15+
test = chainedUnion { x: 1 }

tests-integration/fixtures/checking/212_row_constraint_chained/Main.snap

Lines changed: 18 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
module Main where
2+
3+
import Prim.Row as Prim.Row
4+
5+
foreign import unsafeCoerce :: forall a b. a -> b
6+
7+
multiMerge
8+
:: forall r1 r2 r3 r4 r5 r6
9+
. Prim.Row.Union r1 r2 r3
10+
=> Prim.Row.Nub r3 r4
11+
=> Prim.Row.Union r4 r5 r6
12+
=> Record r1
13+
-> Record r5
14+
-> Record r6
15+
multiMerge _ _ = unsafeCoerce {}
16+
17+
test1 = multiMerge { a: 1 }
18+
19+
test2 = multiMerge { a: 1 } { b: 2 }

tests-integration/fixtures/checking/213_row_constraint_multiple/Main.snap

Lines changed: 27 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests-integration/tests/checking/generated.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,3 +445,11 @@ fn run_test(folder: &str, file: &str) {
445445
#[rustfmt::skip] #[test] fn test_208_int_add_constraint_main() { run_test("208_int_add_constraint", "Main"); }
446446

447447
#[rustfmt::skip] #[test] fn test_209_int_cons_constraint_main() { run_test("209_int_cons_constraint", "Main"); }
448+
449+
#[rustfmt::skip] #[test] fn test_210_row_constraint_generalization_main() { run_test("210_row_constraint_generalization", "Main"); }
450+
451+
#[rustfmt::skip] #[test] fn test_211_row_constraint_result_main() { run_test("211_row_constraint_result", "Main"); }
452+
453+
#[rustfmt::skip] #[test] fn test_212_row_constraint_chained_main() { run_test("212_row_constraint_chained", "Main"); }
454+
455+
#[rustfmt::skip] #[test] fn test_213_row_constraint_multiple_main() { run_test("213_row_constraint_multiple", "Main"); }

0 commit comments

Comments
 (0)