Skip to content

Commit fff6268

Browse files
feat(typecheck): Start typechecking (add statements this commit)
1 parent 664fe3e commit fff6268

File tree

2 files changed

+125
-10
lines changed

2 files changed

+125
-10
lines changed

parser/astgen/astgen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def _walk_smt(self, smt: AnyNode) -> list[AstNode]:
112112
elif isinstance(smt, ConditionalBlock):
113113
return self._walk_conditional(smt)
114114
elif isinstance(smt, AssignNode): # Simple assignment
115+
# TODO: maybe separate SetAttr, SetItem, SetVar nodes?
115116
return [AstAssign(smt.region, self._walk_assign_left(smt.target),
116117
self._walk_expr(smt.source))]
117118
elif isinstance(smt, AssignOpNode): # Other (aug.) assignment

parser/typecheck/typecheck.py

Lines changed: 124 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from __future__ import annotations
22

3+
from collections.abc import Callable
34
from dataclasses import dataclass, field
5+
from typing import TypeAlias
46

57
from util.recursive_eq import recursive_eq
6-
from ..astgen.ast_nodes import (
7-
AstNode, AstIdent, AstDeclNode, AstDefine, VarDeclType, VarDeclScope)
8+
from ..astgen.ast_nodes import *
89
from ..astgen.astgen import AstGen
910
from ..astgen.filtered_walker import FilteredWalker
10-
from ..common import BaseLocatedError, StrRegion
11+
from ..common import BaseLocatedError, StrRegion, region_union, RegionUnionArgT
1112

1213

1314
@dataclass
@@ -193,7 +194,18 @@ def err(self, msg: str, region: StrRegion):
193194
return NameResolutionError(msg, region, self.src)
194195

195196

197+
class TypecheckError(BaseLocatedError):
198+
"""Errors raised by the typechecker"""
199+
200+
201+
NodeTypecheckFn: TypeAlias = 'Callable[[Typechecker, AstNode], TypeInfo | None]'
202+
203+
_typecheck_dispatch: dict[type[AstNode], NodeTypecheckFn] = {}
204+
205+
196206
class Typechecker:
207+
_curr_scope: Scope
208+
197209
def __init__(self, name_resolver: NameResolver):
198210
self.resolver = name_resolver
199211
self.src = self.resolver.src
@@ -203,16 +215,118 @@ def _init(self):
203215
self.resolver.run()
204216
self.ast = self.resolver.ast
205217
self.top_scope = self.resolver.top_scope
218+
self._curr_scope = self.top_scope
206219

207220
def run(self):
208221
if self.is_ok is None:
209222
return self.is_ok
210-
self._typecheck()
211-
self.is_ok = True
223+
self._init()
224+
self._typecheck(self.ast)
225+
self.is_ok = True # didn't raise any errors
212226
return self.is_ok
213227

214-
def _typecheck(self):
215-
walker = FilteredWalker()
216-
217-
self.ast.walk(walker)
218-
...
228+
def _node_typechecker(self, tp=None):
229+
if tp is None:
230+
assert callable(self)
231+
tp = self # Called as decor in this class
232+
233+
def decor(fn: NodeTypecheckFn):
234+
_typecheck_dispatch[tp] = fn
235+
return fn
236+
return decor
237+
238+
def _typecheck(self, n: AstNode):
239+
try:
240+
fn = _typecheck_dispatch[type(n)]
241+
except KeyError:
242+
fn = type(self)._typecheck_node_fallback
243+
return fn(self, n)
244+
245+
def _typecheck_node_fallback(self, n: AstNode):
246+
raise NotImplementedError(f"No typechecker function for node "
247+
f"type {type(n).__name__}")
248+
249+
@_node_typechecker(AstProgramNode)
250+
def _typecheck_program(self, n: AstProgramNode):
251+
self._typecheck_block(n.statements)
252+
253+
def _typecheck_block(self, block: list[AstNode]):
254+
for smt in block:
255+
if (tp := self._typecheck(smt)) is not None:
256+
self.expect_type(tp, VoidType(), smt)
257+
258+
@_node_typechecker(AstDeclNode)
259+
def _typecheck_decl(self, n: AstDeclNode):
260+
if not n.value: # Nothing to check
261+
return
262+
expect = self._resolve_scope(n.scope).declared[n.ident.id].tp_info
263+
self.expect_type(self._typecheck(n.value), expect, n)
264+
265+
@_node_typechecker(AstRepeat)
266+
def _typecheck_repeat(self, n: AstRepeat):
267+
# For now, we don't differentiate between number/string (as sc doesn't)
268+
self.expect_type(self._typecheck(n.count), ValType(), n.count)
269+
self._typecheck_block(n.body)
270+
271+
@_node_typechecker(AstIf)
272+
def _typecheck_if(self, n: AstIf):
273+
self.expect_type(self._typecheck(n.cond), BoolType(), n.cond)
274+
self._typecheck_block(n.if_body)
275+
if n.else_body is not None:
276+
self._typecheck_block(n.else_body)
277+
278+
@_node_typechecker(AstWhile)
279+
def _typecheck_while(self, n: AstWhile):
280+
self.expect_type(self._typecheck(n.cond), BoolType(), n.cond)
281+
self._typecheck_block(n.body)
282+
283+
@_node_typechecker(AstAssign)
284+
def _typecheck_assign(self, n: AstAssign): # super tempted to call this _typecheck_ass
285+
if isinstance(n.target, AstIdent):
286+
target_tp = self._curr_scope.used[n.target.id].tp_info
287+
elif isinstance(n.target, AstItem): # ls[i] = v
288+
target_tp = self._typecheck(n.target) # Also checks that `ls` is a list
289+
elif isinstance(n.target, AstAttribute):
290+
raise self.err("Setting attributes is currently unsupported", n.target)
291+
else:
292+
assert 0, "Unknown simple-assignment type"
293+
if target_tp == ListType():
294+
raise self.err("Cannot assign directly to list", n)
295+
self.expect_type(self._typecheck(n.source), target_tp, n)
296+
297+
@_node_typechecker(AstAugAssign)
298+
def _typecheck_aug_assign(self, n: AstAugAssign):
299+
# TODO: change this when desugaring is implemented
300+
# (for now only +=, only on variables)
301+
if n.op != '+=':
302+
raise self.err(f"The '{n.op}' operator is not implemented", n)
303+
if not isinstance(n.target, AstIdent):
304+
raise self.err(f"The '+=' operator is only implemented for variables", n)
305+
target_tp = self._curr_scope.used[n.target.id].tp_info
306+
if target_tp != ValType():
307+
raise self.err(f"Cannot apply += to {target_tp}", n)
308+
self.expect_type(self._typecheck(n.source), ValType(), n.source)
309+
310+
@_node_typechecker(AstDefine)
311+
def _typecheck_define(self, n: AstDefine):
312+
# Don't really need to check much here - type is generated from the
313+
# syntax so must be correct. Set _curr_scope and check body
314+
func_info = self._curr_scope.declared[n.ident.id]
315+
assert isinstance(func_info, FuncInfo)
316+
old_scope = self._curr_scope
317+
self._curr_scope = func_info.subscope
318+
try:
319+
self._typecheck_block(n.body)
320+
finally:
321+
self._curr_scope = old_scope
322+
323+
def _resolve_scope(self, scope_tp: VarDeclScope):
324+
return self.top_scope if scope_tp == VarDeclScope.GLOBAL else self._curr_scope
325+
326+
def err(self, msg: str, loc: RegionUnionArgT):
327+
return TypecheckError(msg, region_union(loc), self.src)
328+
329+
def expect_type(self, actual: TypeInfo, exp: TypeInfo, loc: RegionUnionArgT):
330+
if exp != actual:
331+
# TODO: maybe better type formatting
332+
raise self.err(f"Expected type {exp}, got type {actual}", loc)

0 commit comments

Comments
 (0)