diff --git a/src/parser.rs b/src/parser.rs index 0f0aed8..fa2512c 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -138,6 +138,9 @@ pub enum AstNode { name: NodeId, ty: Option, }, + InOutTypes(Vec), + /// Input/output type pair for a command + InOutType(NodeId, NodeId), Closure { params: Option, block: NodeId, @@ -987,6 +990,52 @@ impl Parser { } } + pub fn in_out_type(&mut self) -> NodeId { + let _span = span!(); + let span_start = self.position(); + + let in_ty = self.typename(); + self.thin_arrow(); + let out_ty = self.typename(); + + let span_end = self.position(); + self.create_node(AstNode::InOutType(in_ty, out_ty), span_start, span_end) + } + + pub fn in_out_types(&mut self) -> NodeId { + let _span = span!(); + self.colon(); + + if self.is_lsquare() { + let span_start = self.position(); + + self.tokens.advance(); + + let mut output = vec![]; + while self.has_tokens() { + if self.is_rsquare() { + break; + } + + if self.is_comma() { + self.tokens.advance(); + continue; + } + + output.push(self.in_out_type()); + } + + self.rsquare(); + let span_end = self.position(); + + self.create_node(AstNode::InOutTypes(output), span_start, span_end) + } else { + let ty = self.in_out_type(); + let span = self.compiler.get_span(ty); + self.create_node(AstNode::InOutTypes(vec![ty]), span.start, span.end) + } + } + pub fn def_statement(&mut self) -> NodeId { let _span = span!(); let span_start = self.position(); @@ -1002,6 +1051,11 @@ impl Parser { }; let params = self.signature_params(ParamsContext::Squares); + let return_ty = if self.is_colon() { + Some(self.in_out_types()) + } else { + None + }; let block = self.block(BlockContext::Curlies); let span_end = self.get_span_end(block); @@ -1010,7 +1064,7 @@ impl Parser { AstNode::Def { name, params, - return_ty: None, + return_ty, block, }, span_start, @@ -1571,6 +1625,14 @@ impl Parser { } } + pub fn thin_arrow(&mut self) { + if self.is_thin_arrow() { + self.tokens.advance(); + } else { + self.error("expected: thin arrow '->'"); + } + } + pub fn colon(&mut self) { if self.is_colon() { self.tokens.advance(); diff --git a/src/snapshots/new_nu_parser__test__node_output@def_return_type.nu.snap b/src/snapshots/new_nu_parser__test__node_output@def_return_type.nu.snap new file mode 100644 index 0000000..7c27ecd --- /dev/null +++ b/src/snapshots/new_nu_parser__test__node_output@def_return_type.nu.snap @@ -0,0 +1,86 @@ +--- +source: src/test.rs +expression: evaluate_example(path) +input_file: tests/def_return_type.nu +--- +==== COMPILER ==== +0: Name (4 to 7) "foo" +1: Params([]) (8 to 11) +2: Name (14 to 21) "nothing" +3: Type { name: NodeId(2), params: None, optional: false } (14 to 21) +4: Name (25 to 29) "list" +5: Name (30 to 33) "any" +6: Type { name: NodeId(5), params: None, optional: false } (30 to 33) +7: Params([NodeId(6)]) (29 to 34) +8: Type { name: NodeId(4), params: Some(NodeId(7)), optional: false } (25 to 29) +9: InOutType(NodeId(3), NodeId(8)) (14 to 35) +10: InOutTypes([NodeId(9)]) (14 to 35) +11: List([]) (37 to 38) +12: Block(BlockId(0)) (35 to 41) +13: Def { name: NodeId(0), params: NodeId(1), return_ty: Some(NodeId(10)), block: NodeId(12) } (0 to 41) +14: Name (46 to 49) "bar" +15: Params([]) (50 to 53) +16: Name (58 to 64) "string" +17: Type { name: NodeId(16), params: None, optional: false } (58 to 64) +18: Name (68 to 72) "list" +19: Name (73 to 79) "string" +20: Type { name: NodeId(19), params: None, optional: false } (73 to 79) +21: Params([NodeId(20)]) (72 to 80) +22: Type { name: NodeId(18), params: Some(NodeId(21)), optional: false } (68 to 72) +23: InOutType(NodeId(17), NodeId(22)) (58 to 80) +24: Name (82 to 85) "int" +25: Type { name: NodeId(24), params: None, optional: false } (82 to 85) +26: Name (89 to 93) "list" +27: Name (94 to 97) "int" +28: Type { name: NodeId(27), params: None, optional: false } (94 to 97) +29: Params([NodeId(28)]) (93 to 98) +30: Type { name: NodeId(26), params: Some(NodeId(29)), optional: false } (89 to 93) +31: InOutType(NodeId(25), NodeId(30)) (82 to 99) +32: InOutTypes([NodeId(23), NodeId(31)]) (56 to 101) +33: List([]) (103 to 104) +34: Block(BlockId(1)) (101 to 107) +35: Def { name: NodeId(14), params: NodeId(15), return_ty: Some(NodeId(32)), block: NodeId(34) } (42 to 107) +36: Block(BlockId(2)) (0 to 108) +==== SCOPE ==== +0: Frame Scope, node_id: NodeId(36) + decls: [ bar: NodeId(14), foo: NodeId(0) ] +1: Frame Scope, node_id: NodeId(12) (empty) +2: Frame Scope, node_id: NodeId(34) (empty) +==== TYPES ==== +0: unknown +1: forbidden +2: unknown +3: unknown +4: unknown +5: unknown +6: any +7: forbidden +8: unknown +9: unknown +10: unknown +11: list +12: list +13: () +14: unknown +15: forbidden +16: unknown +17: unknown +18: unknown +19: unknown +20: string +21: forbidden +22: unknown +23: unknown +24: unknown +25: unknown +26: unknown +27: unknown +28: int +29: forbidden +30: unknown +31: unknown +32: unknown +33: list +34: list +35: () +36: () diff --git a/src/typechecker.rs b/src/typechecker.rs index 7a3717d..c5d4ae5 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -626,10 +626,46 @@ impl<'a> Typechecker<'a> { &mut self, name: NodeId, params: NodeId, - _return_ty: Option, + return_ty: Option, block: NodeId, node_id: NodeId, ) { + let return_ty = return_ty + .map(|ty| { + let AstNode::InOutTypes(types) = self.compiler.get_node(ty) else { + panic!("internal error: return type is not a return type"); + }; + types + .iter() + .map(|ty| { + let AstNode::InOutType(in_ty, out_ty) = self.compiler.get_node(*ty) else { + panic!("internal error: return type is not a return type"); + }; + let AstNode::Type { + name: in_name, + params: in_params, + optional: in_optional, + } = *self.compiler.get_node(*in_ty) + else { + panic!("internal error: type is not a type"); + }; + let AstNode::Type { + name: out_name, + params: out_params, + optional: out_optional, + } = *self.compiler.get_node(*out_ty) + else { + panic!("internal error: type is not a type"); + }; + InOutType { + in_type: self.typecheck_type(in_name, in_params, in_optional), + out_type: self.typecheck_type(out_name, out_params, out_optional), + } + }) + .collect::>() + }) + .unwrap_or_default(); + self.typecheck_node(params); self.typecheck_node(block); self.set_node_type_id(node_id, NONE_TYPE); @@ -641,10 +677,15 @@ impl<'a> Typechecker<'a> { .get(&name) .expect("missing declared decl"); - self.decl_types[decl_id.0] = vec![InOutType { - in_type: ANY_TYPE, - out_type: self.type_id_of(block), - }]; + if return_ty.is_empty() { + self.decl_types[decl_id.0] = vec![InOutType { + in_type: ANY_TYPE, + out_type: self.type_id_of(block), + }]; + } else { + // TODO check that block output type matches expected type + self.decl_types[decl_id.0] = return_ty; + } } fn typecheck_alias(&mut self, new_name: NodeId, old_name: NodeId, node_id: NodeId) { diff --git a/tests/def_return_type.nu b/tests/def_return_type.nu new file mode 100644 index 0000000..1a75f31 --- /dev/null +++ b/tests/def_return_type.nu @@ -0,0 +1,2 @@ +def foo [ ] : nothing -> list { [] } +def bar [ ] : [ string -> list, int -> list ] { [] }