Skip to content

Commit 08f75c3

Browse files
committed
Correctly add subtype constraints to type vars
1 parent 6f24078 commit 08f75c3

10 files changed

+99
-136
lines changed

src/snapshots/new_nu_parser__test__lexer.snap

Lines changed: 0 additions & 21 deletions
This file was deleted.

src/snapshots/new_nu_parser__test__node_output@binary_ops_mismatch.nu.snap

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ input_file: tests/binary_ops_mismatch.nu
4343
16: bool
4444
==== TYPE ERRORS ====
4545
Error (NodeId 2): Expected string, got float
46-
Error (NodeId 4): Expected list<top>, got string
47-
Error (NodeId 6): Expected list<top>, got float
46+
Error (NodeId 4): Expected list<bottom> <: '0 <: list<top>, got string
47+
Error (NodeId 6): Expected list<bottom> <: '0 <: list<top>, got float
4848
Error (NodeId 5): type mismatch: unsupported append between string and float
4949
Error (NodeId 10): Expected bool, got string
5050
Error (NodeId 12): Expected string, got bool

src/snapshots/new_nu_parser__test__node_output@def_complex.nu.snap renamed to src/snapshots/new_nu_parser__test__node_output@infer_complex.nu.snap

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
---
22
source: src/test.rs
33
expression: evaluate_example(path)
4-
input_file: tests/def_complex.nu
4+
input_file: tests/infer_complex.nu
55
---
66
==== COMPILER ====
77
0: Name (4 to 5) "f"
@@ -196,7 +196,7 @@ input_file: tests/def_complex.nu
196196
86: unknown
197197
87: string
198198
88: record<a: float, b: string>
199-
89: record<a: number, b: top>
199+
89: record<a: number, b: string>
200200
90: ()
201201
91: ()
202202
==== IR ====

src/snapshots/new_nu_parser__test__node_output@def_generics.nu.snap renamed to src/snapshots/new_nu_parser__test__node_output@infer_generics.nu.snap

Lines changed: 10 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
---
22
source: src/test.rs
33
expression: evaluate_example(path)
4-
input_file: tests/def_generics.nu
4+
input_file: tests/infer_generics.nu
55
---
66
==== COMPILER ====
77
0: Name (4 to 5) "f"
@@ -30,28 +30,12 @@ input_file: tests/def_generics.nu
3030
23: List([NodeId(22)]) (59 to 62)
3131
24: Block(BlockId(0)) (39 to 65)
3232
25: Def { name: NodeId(0), type_params: Some(NodeId(2)), params: NodeId(7), in_out_types: Some(NodeId(16)), block: NodeId(24) } (0 to 65)
33-
26: Variable (71 to 73) "l1"
34-
27: Name (76 to 77) "f"
35-
28: Int (78 to 79) "1"
36-
29: Call { parts: [NodeId(27), NodeId(28)] } (78 to 79)
37-
30: Let { variable_name: NodeId(26), ty: None, initializer: NodeId(29), is_mutable: false } (67 to 79)
38-
31: Variable (85 to 87) "l2"
39-
32: Name (90 to 91) "f"
40-
33: Int (92 to 93) "2"
41-
34: Call { parts: [NodeId(32), NodeId(33)] } (92 to 93)
42-
35: Let { variable_name: NodeId(31), ty: None, initializer: NodeId(34), is_mutable: false } (81 to 93)
43-
36: Variable (98 to 100) "l3"
44-
37: Name (102 to 106) "list"
45-
38: Name (107 to 113) "number"
46-
39: Type { name: NodeId(38), args: None, optional: false } (107 to 113)
47-
40: TypeArgs([NodeId(39)]) (106 to 114)
48-
41: Type { name: NodeId(37), args: Some(NodeId(40)), optional: false } (102 to 106)
49-
42: Variable (117 to 120) "$l2"
50-
43: Let { variable_name: NodeId(36), ty: Some(NodeId(41)), initializer: NodeId(42), is_mutable: false } (94 to 120)
51-
44: Block(BlockId(1)) (0 to 121)
33+
26: Name (67 to 68) "f"
34+
27: Int (69 to 70) "1"
35+
28: Call { parts: [NodeId(26), NodeId(27)] } (69 to 70)
36+
29: Block(BlockId(1)) (0 to 71)
5237
==== SCOPE ====
53-
0: Frame Scope, node_id: NodeId(44)
54-
variables: [ l1: NodeId(26), l2: NodeId(31), l3: NodeId(36) ]
38+
0: Frame Scope, node_id: NodeId(29)
5539
decls: [ f: NodeId(0) ]
5640
1: Frame Scope, node_id: NodeId(24)
5741
variables: [ x: NodeId(3), z: NodeId(17) ]
@@ -83,25 +67,10 @@ input_file: tests/def_generics.nu
8367
23: list<T>
8468
24: list<T>
8569
25: ()
86-
26: list<top>
87-
27: unknown
88-
28: int
89-
29: list<top>
90-
30: ()
91-
31: list<number>
92-
32: unknown
93-
33: int
94-
34: list<number>
95-
35: ()
96-
36: list<number>
97-
37: unknown
98-
38: unknown
99-
39: number
100-
40: forbidden
101-
41: list<number>
102-
42: list<number>
103-
43: ()
104-
44: ()
70+
26: unknown
71+
27: int
72+
28: list<int>
73+
29: list<int>
10574
==== IR ====
10675
register_count: 0
10776
file_count: 0

src/snapshots/new_nu_parser__test__node_output@def_plus.nu.snap renamed to src/snapshots/new_nu_parser__test__node_output@infer_plus.nu.snap

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
---
22
source: src/test.rs
33
expression: evaluate_example(path)
4-
input_file: tests/def_plus.nu
4+
input_file: tests/infer_plus.nu
55
---
66
==== COMPILER ====
77
0: Name (4 to 14) "mysterious"
@@ -21,23 +21,21 @@ input_file: tests/def_plus.nu
2121
14: Block(BlockId(0)) (44 to 46)
2222
15: Def { name: NodeId(0), type_params: Some(NodeId(2)), params: NodeId(7), in_out_types: Some(NodeId(13)), block: NodeId(14) } (0 to 46)
2323
16: Variable (52 to 53) "m"
24-
17: Name (55 to 58) "any"
25-
18: Type { name: NodeId(17), args: None, optional: false } (55 to 58)
26-
19: Name (61 to 71) "mysterious"
27-
20: Int (72 to 73) "0"
28-
21: Call { parts: [NodeId(19), NodeId(20)] } (72 to 73)
29-
22: Let { variable_name: NodeId(16), ty: Some(NodeId(18)), initializer: NodeId(21), is_mutable: false } (48 to 73)
30-
23: Variable (75 to 77) "$m"
31-
24: Plus (78 to 79)
32-
25: String (80 to 85) ""foo""
33-
26: BinaryOp { lhs: NodeId(23), op: NodeId(24), rhs: NodeId(25) } (75 to 85)
34-
27: Variable (86 to 88) "$m"
35-
28: Plus (89 to 90)
36-
29: Int (91 to 94) "123"
37-
30: BinaryOp { lhs: NodeId(27), op: NodeId(28), rhs: NodeId(29) } (86 to 94)
38-
31: Block(BlockId(1)) (0 to 95)
24+
17: Name (56 to 66) "mysterious"
25+
18: Int (67 to 68) "0"
26+
19: Call { parts: [NodeId(17), NodeId(18)] } (67 to 68)
27+
20: Let { variable_name: NodeId(16), ty: None, initializer: NodeId(19), is_mutable: false } (48 to 68)
28+
21: Variable (70 to 72) "$m"
29+
22: Plus (73 to 74)
30+
23: String (75 to 80) ""foo""
31+
24: BinaryOp { lhs: NodeId(21), op: NodeId(22), rhs: NodeId(23) } (70 to 80)
32+
25: Variable (81 to 83) "$m"
33+
26: Plus (84 to 85)
34+
27: Int (86 to 89) "123"
35+
28: BinaryOp { lhs: NodeId(25), op: NodeId(26), rhs: NodeId(27) } (81 to 89)
36+
29: Block(BlockId(1)) (0 to 90)
3937
==== SCOPE ====
40-
0: Frame Scope, node_id: NodeId(31)
38+
0: Frame Scope, node_id: NodeId(29)
4139
variables: [ m: NodeId(16) ]
4240
decls: [ mysterious: NodeId(0) ]
4341
1: Frame Scope, node_id: NodeId(14)
@@ -60,22 +58,22 @@ input_file: tests/def_plus.nu
6058
13: unknown
6159
14: ()
6260
15: ()
63-
16: any
61+
16: bottom
6462
17: unknown
65-
18: any
66-
19: unknown
67-
20: int
68-
21: top
69-
22: ()
70-
23: any
71-
24: forbidden
72-
25: string
73-
26: string
74-
27: any
75-
28: forbidden
76-
29: int
77-
30: number
78-
31: number
63+
18: int
64+
19: bottom
65+
20: ()
66+
21: bottom
67+
22: forbidden
68+
23: string
69+
24: string
70+
25: bottom
71+
26: forbidden
72+
27: int
73+
28: string
74+
29: string
75+
==== TYPE ERRORS ====
76+
Error (NodeId 27): Expected string, got int
7977
==== IR ====
8078
register_count: 0
8179
file_count: 0

src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ input_file: tests/let_mismatch.nu
9292
40: ()
9393
==== TYPE ERRORS ====
9494
Error (NodeId 13): Expected string, got int
95+
Error (NodeId 24): Expected int, got string
96+
Error (NodeId 25): Expected list<int>, got list<string>
9597
Error (NodeId 26): Expected list<list<int>>, got list<list<string>>
9698
Error (NodeId 38): Expected record<a: int>, got record<a: string>
9799
==== IR ====

src/typechecker.rs

Lines changed: 51 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,17 @@ impl<'a> Typechecker<'a> {
215215
let last_node_id = NodeId(last);
216216
self.typecheck_node(last_node_id);
217217

218+
for i in 0..self.type_vars.len() {
219+
let var = &self.type_vars[i];
220+
let bound = var.lower_bound;
221+
let cleaned = self.eliminate_type_vars(bound, TypeVarId(0), true);
222+
self.types[bound.0] = self.types[cleaned.0];
223+
}
224+
218225
for i in 0..self.types.len() {
219-
let res = self.eliminate_type_vars(TypeId(i), TypeVarId(0), false);
220-
if res.0 != i {
221-
self.types[i] = self.types[res.0];
226+
if let Type::Var(var_id) = &self.types[i] {
227+
let bound = self.type_vars[var_id.0].lower_bound;
228+
self.types[i] = self.types[bound.0];
222229
}
223230
}
224231
}
@@ -389,9 +396,10 @@ impl<'a> Typechecker<'a> {
389396
AstNode::True | AstNode::False => BOOL_TYPE,
390397
AstNode::String => STRING_TYPE,
391398
AstNode::List(ref items) => {
392-
// TODO inspect the expected type and infer a union type instead
399+
// TODO infer a union type instead
393400
if let Some(first_id) = items.first() {
394-
self.typecheck_expr(*first_id, TOP_TYPE);
401+
let expected_elem = self.extract_elem_type(expected);
402+
self.typecheck_expr(*first_id, expected_elem.unwrap_or(TOP_TYPE));
395403
let first_type = self.type_of(*first_id);
396404

397405
let mut all_numbers = self.is_type_compatible(first_type, Type::Number);
@@ -709,25 +717,22 @@ impl<'a> Typechecker<'a> {
709717
}
710718
}
711719
AstNode::Append => {
712-
// TODO cache this type
713-
let list_ty = self.push_type(Type::List(TOP_TYPE));
714-
let lhs_type = self.typecheck_expr(lhs, list_ty);
715-
let rhs_type = self.typecheck_expr(rhs, list_ty);
716-
717-
//todo account for any
718-
match (self.types[lhs_type.0], self.types[rhs_type.0]) {
719-
(Type::List(lhs_item), Type::List(rhs_item)) => {
720-
let mut types = HashSet::new();
721-
types.insert(lhs_item);
722-
types.insert(rhs_item);
723-
let common_type = self.create_oneof(types);
724-
self.push_type(Type::List(common_type))
725-
}
726-
(_, Type::Any | Type::Bottom) | (Type::Any | Type::Bottom, _) => ANY_TYPE,
727-
_ => {
728-
self.binary_op_err("append", lhs, op, rhs);
729-
ERROR_TYPE
730-
}
720+
// TODO cache these two types
721+
let top_list = self.push_type(Type::List(TOP_TYPE));
722+
let bottom_list = self.push_type(Type::List(BOTTOM_TYPE));
723+
724+
let res_var = self.new_typevar(bottom_list, top_list);
725+
let res_type = self.push_type(Type::Var(res_var));
726+
let lhs_type = self.typecheck_expr(lhs, res_type);
727+
let rhs_type = self.typecheck_expr(rhs, res_type);
728+
729+
if self.is_subtype(lhs_type, LIST_ANY_TYPE)
730+
&& self.is_subtype(rhs_type, LIST_ANY_TYPE)
731+
{
732+
res_type
733+
} else {
734+
self.binary_op_err("append", lhs, op, rhs);
735+
ERROR_TYPE
731736
}
732737
}
733738
AstNode::Assignment
@@ -1162,6 +1167,18 @@ impl<'a> Typechecker<'a> {
11621167
}
11631168
}
11641169

1170+
/// Given the type for a list, extract the type of its elements
1171+
fn extract_elem_type(&mut self, list_ty: TypeId) -> Option<TypeId> {
1172+
match self.types[list_ty.0] {
1173+
Type::List(elem) => Some(elem),
1174+
Type::Top => Some(TOP_TYPE),
1175+
Type::Bottom => Some(BOTTOM_TYPE),
1176+
Type::Any => Some(ANY_TYPE),
1177+
Type::Unknown => Some(UNKNOWN_TYPE),
1178+
_ => None,
1179+
}
1180+
}
1181+
11651182
fn set_node_type_id(&mut self, node_id: NodeId, type_id: TypeId) {
11661183
self.node_types[node_id.0] = type_id;
11671184
}
@@ -1223,7 +1240,10 @@ impl<'a> Typechecker<'a> {
12231240
(Type::Var(var_id), _) => {
12241241
let lb = self.type_vars[var_id.0].lower_bound;
12251242
let ub = self.type_vars[var_id.0].upper_bound;
1226-
let new_ub = self.create_intersection(ub, supe_id);
1243+
let mut types = HashSet::new();
1244+
types.insert(ub);
1245+
types.insert(supe_id);
1246+
let new_ub = self.create_allof(types);
12271247
// Prevent forward references/cycles
12281248
let new_ub = self.eliminate_type_vars(new_ub, var_id, true);
12291249

@@ -1241,7 +1261,10 @@ impl<'a> Typechecker<'a> {
12411261
(_, Type::Var(var_id)) => {
12421262
let lb = self.type_vars[var_id.0].lower_bound;
12431263
let ub = self.type_vars[var_id.0].upper_bound;
1244-
let new_lb = self.create_intersection(lb, sub_id);
1264+
let mut types = HashSet::new();
1265+
types.insert(lb);
1266+
types.insert(sub_id);
1267+
let new_lb = self.create_oneof(types);
12451268
// Prevent forward references/cycles
12461269
let new_lb = self.eliminate_type_vars(new_lb, var_id, false);
12471270

@@ -1598,6 +1621,8 @@ impl<'a> Typechecker<'a> {
15981621
let mut flattened = HashSet::new();
15991622
for ty_id in types {
16001623
match self.types[ty_id.0] {
1624+
Type::Top | Type::Any | Type::Unknown => return ty_id,
1625+
Type::Bottom => {}
16011626
Type::OneOf(id) => {
16021627
flattened.extend(&self.oneof_types[id.0]);
16031628
}
@@ -1846,11 +1871,4 @@ impl<'a> Typechecker<'a> {
18461871

18471872
self.create_oneof(inters)
18481873
}
1849-
1850-
fn create_intersection(&mut self, lhs_id: TypeId, rhs_id: TypeId) -> TypeId {
1851-
let mut types = HashSet::new();
1852-
types.insert(lhs_id);
1853-
types.insert(rhs_id);
1854-
self.create_allof(types)
1855-
}
18561874
}
Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,4 @@ def f<T> [ x: T ] : nothing -> list<T> {
33
[$z]
44
}
55

6-
let l1 = f 1
7-
8-
let l2 = f 2
9-
let l3: list<number> = $l2
6+
f 1
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
def mysterious<T> [ x: int ] : nothing -> T {}
22

3-
let m: any = mysterious 0
3+
let m = mysterious 0
44

55
$m + "foo"
66
$m + 123

0 commit comments

Comments
 (0)