Skip to content

Commit 8e6be34

Browse files
committed
Simple type inference
1 parent 3fccf20 commit 8e6be34

10 files changed

+351
-177
lines changed

src/compiler.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,10 @@ pub struct Compiler {
5757
pub variables: Vec<Variable>,
5858
/// Mapping of variable's name node -> Variable
5959
pub var_resolution: HashMap<NodeId, VarId>,
60-
/// Declarations (commands, aliases, externs), indexed by VarId
60+
/// Declarations (commands, aliases, externs), indexed by DeclId
6161
pub decls: Vec<Box<dyn Command>>,
62+
/// Declaration NodeIds, indexed by DeclId
63+
pub decl_nodes: Vec<NodeId>,
6264
/// Mapping of decl's name node -> Command
6365
pub decl_resolution: HashMap<NodeId, DeclId>,
6466

@@ -95,6 +97,7 @@ impl Compiler {
9597
variables: vec![],
9698
var_resolution: HashMap::new(),
9799
decls: vec![],
100+
decl_nodes: vec![],
98101
decl_resolution: HashMap::new(),
99102

100103
// variables: vec![],
@@ -158,6 +161,7 @@ impl Compiler {
158161
self.variables.extend(name_bindings.variables);
159162
self.var_resolution.extend(name_bindings.var_resolution);
160163
self.decls.extend(name_bindings.decls);
164+
self.decl_nodes.extend(name_bindings.decl_nodes);
161165
self.decl_resolution.extend(name_bindings.decl_resolution);
162166
self.errors.extend(name_bindings.errors);
163167
}

src/resolver.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ pub struct NameBindings {
5858
pub variables: Vec<Variable>,
5959
pub var_resolution: HashMap<NodeId, VarId>,
6060
pub decls: Vec<Box<dyn Command>>,
61+
pub decl_nodes: Vec<NodeId>,
6162
pub decl_resolution: HashMap<NodeId, DeclId>,
6263
pub errors: Vec<SourceError>,
6364
}
@@ -70,6 +71,7 @@ impl NameBindings {
7071
variables: vec![],
7172
var_resolution: HashMap::new(),
7273
decls: vec![],
74+
decl_nodes: vec![],
7375
decl_resolution: HashMap::new(),
7476
errors: vec![],
7577
}
@@ -96,6 +98,8 @@ pub struct Resolver<'a> {
9698
pub var_resolution: HashMap<NodeId, VarId>,
9799
/// Declarations (commands, aliases, etc.), indexed by DeclId
98100
pub decls: Vec<Box<dyn Command>>,
101+
/// Declaration NodeIds, indexed by DeclId
102+
pub decl_nodes: Vec<NodeId>,
99103
/// Mapping of decl's name node -> Command
100104
pub decl_resolution: HashMap<NodeId, DeclId>,
101105
/// Errors encountered during name binding
@@ -111,6 +115,7 @@ impl<'a> Resolver<'a> {
111115
variables: vec![],
112116
var_resolution: HashMap::new(),
113117
decls: vec![],
118+
decl_nodes: vec![],
114119
decl_resolution: HashMap::new(),
115120
errors: vec![],
116121
}
@@ -123,6 +128,7 @@ impl<'a> Resolver<'a> {
123128
variables: self.variables,
124129
var_resolution: self.var_resolution,
125130
decls: self.decls,
131+
decl_nodes: self.decl_nodes,
126132
decl_resolution: self.decl_resolution,
127133
errors: self.errors,
128134
}
@@ -201,7 +207,7 @@ impl<'a> Resolver<'a> {
201207
// TODO: Move node_id param to the end, same as in typechecker
202208
match self.compiler.ast_nodes[node_id.0] {
203209
AstNode::Expr(ref expr) => self.resolve_expr(expr.clone(), node_id),
204-
AstNode::Stmt(ref stmt) => self.resolve_stmt(stmt.clone()),
210+
AstNode::Stmt(ref stmt) => self.resolve_stmt(stmt.clone(), node_id),
205211
AstNode::Params(ref params) => {
206212
for param in params {
207213
if let AstNode::Param { name, .. } = self.compiler.ast_nodes[param.0] {
@@ -294,7 +300,7 @@ impl<'a> Resolver<'a> {
294300
}
295301
}
296302

297-
pub fn resolve_stmt(&mut self, stmt: Stmt) {
303+
pub fn resolve_stmt(&mut self, stmt: Stmt, node_id: NodeId) {
298304
match stmt {
299305
Stmt::Def {
300306
name,
@@ -303,7 +309,7 @@ impl<'a> Resolver<'a> {
303309
block,
304310
} => {
305311
// define the command before the block to enable recursive calls
306-
self.define_decl(name);
312+
self.define_decl(name, node_id);
307313

308314
// making sure the def parameters and body end up in the same scope frame
309315
self.enter_scope(block);
@@ -320,7 +326,7 @@ impl<'a> Resolver<'a> {
320326
new_name,
321327
old_name: _,
322328
} => {
323-
self.define_decl(new_name);
329+
self.define_decl(new_name, node_id);
324330
}
325331
Stmt::Let {
326332
variable_name,
@@ -491,7 +497,7 @@ impl<'a> Resolver<'a> {
491497
self.var_resolution.insert(var_name_id, var_id);
492498
}
493499

494-
pub fn define_decl(&mut self, decl_name_id: NodeId) {
500+
pub fn define_decl(&mut self, decl_name_id: NodeId, decl_node_id: NodeId) {
495501
// TODO: Deduplicate code with define_variable()
496502
let decl_name = self.compiler.get_span_contents(decl_name_id);
497503
let decl_name = trim_decl_name(decl_name).to_vec();
@@ -509,6 +515,8 @@ impl<'a> Resolver<'a> {
509515
self.decls.push(Box::new(decl));
510516
let decl_id = DeclId(self.decls.len() - 1);
511517

518+
self.decl_nodes.push(decl_node_id);
519+
512520
// let the definition of a decl also count as its use
513521
self.decl_resolution.insert(decl_name_id, decl_id);
514522
}

src/snapshots/[email protected]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ input_file: tests/calls.nu
8585
32: string
8686
33: string
8787
34: int
88-
35: any
88+
35: list<any>
8989
36: unknown
9090
37: stream<binary>
9191
38: stream<binary>
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
---
2+
source: src/test.rs
3+
expression: evaluate_example(path)
4+
input_file: tests/calls_invalid.nu
5+
---
6+
==== COMPILER ====
7+
0: Name (4 to 7) "foo"
8+
1: Name (10 to 11) "a"
9+
2: Name (13 to 16) "int"
10+
3: Type(Ref { name: NodeId(2), args: None, optional: false }) (13 to 16)
11+
4: Param { name: NodeId(1), ty: Some(NodeId(3)) } (10 to 16)
12+
5: Params([NodeId(4)]) (8 to 18)
13+
6: Expr(Block(BlockId(0))) (19 to 21)
14+
7: Stmt(Def { name: NodeId(0), params: NodeId(5), in_out_types: None, block: NodeId(6) }) (0 to 21)
15+
8: Name (22 to 25) "foo"
16+
9: Expr(Int) (26 to 27) "1"
17+
10: Expr(Int) (28 to 29) "2"
18+
11: Expr(Call { parts: [NodeId(8), NodeId(9), NodeId(10)] }) (26 to 29)
19+
12: Name (30 to 33) "foo"
20+
13: Expr(String) (34 to 42) ""string""
21+
14: Expr(Call { parts: [NodeId(12), NodeId(13)] }) (34 to 42)
22+
15: Expr(Block(BlockId(1))) (0 to 43)
23+
==== SCOPE ====
24+
0: Frame Scope, node_id: NodeId(15)
25+
decls: [ foo: NodeId(0) ]
26+
1: Frame Scope, node_id: NodeId(6)
27+
variables: [ a: NodeId(1) ]
28+
==== TYPES ====
29+
0: unknown
30+
1: unknown
31+
2: unknown
32+
3: int
33+
4: int
34+
5: forbidden
35+
6: ()
36+
7: ()
37+
8: unknown
38+
9: int
39+
10: unknown
40+
11: ()
41+
12: unknown
42+
13: string
43+
14: ()
44+
15: ()
45+
==== TYPE ERRORS ====
46+
Error (NodeId 11): Expected 1 argument(s), got 2
47+
Error (NodeId 13): Expected int, got string
48+
==== IR ====
49+
register_count: 0
50+
file_count: 0
51+
==== IR ERRORS ====
52+
Error (NodeId 7): node Stmt(Def { name: NodeId(0), params: NodeId(5), in_out_types: None, block: NodeId(6) }) not suported yet

src/snapshots/new_nu_parser__test__node_output@for_break_continue.nu.snap

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,15 @@ input_file: tests/for_break_continue.nu
5757
10: int
5858
11: bool
5959
12: unknown
60-
13: unknown
61-
14: oneof<(), unknown>
60+
13: ()
61+
14: ()
6262
15: int
6363
16: forbidden
6464
17: int
6565
18: bool
6666
19: unknown
67-
20: unknown
68-
21: oneof<(), unknown>
67+
20: ()
68+
21: ()
6969
22: int
7070
23: forbidden
7171
24: int

src/snapshots/new_nu_parser__test__node_output@invalid_if.nu.snap

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ input_file: tests/invalid_if.nu
2121
2: int
2222
3: int
2323
4: int
24-
5: error
25-
6: error
24+
5: int
25+
6: int
2626
==== TYPE ERRORS ====
27-
Error (NodeId 0): The condition for if branch is not a boolean
27+
Error (NodeId 0): Expected bool, got int
2828
==== IR ====
2929
register_count: 0
3030
file_count: 0

src/snapshots/new_nu_parser__test__node_output@let_mismatch.nu.snap

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ input_file: tests/let_mismatch.nu
9191
39: ()
9292
40: ()
9393
==== TYPE ERRORS ====
94-
Error (NodeId 13): initializer does not match declared type
95-
Error (NodeId 26): initializer does not match declared type
96-
Error (NodeId 38): initializer does not match declared type
94+
Error (NodeId 13): Expected string, got int
95+
Error (NodeId 26): Expected list<list<int>>, got list<list<string>>
96+
Error (NodeId 38): Expected record<a: int>, got record<a: string>
9797
==== IR ====
9898
register_count: 0
9999
file_count: 0

src/snapshots/[email protected]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ input_file: tests/loop.nu
4343
13: unknown
4444
14: unknown
4545
15: unknown
46-
16: unknown
46+
16: ()
4747
==== TYPE ERRORS ====
4848
Error (NodeId 15): unsupported ast node 'Stmt(Loop { block: NodeId(14) })' in typechecker
4949
==== IR ====

0 commit comments

Comments
 (0)