11from __future__ import annotations
22
3+ from collections .abc import Callable
34from dataclasses import dataclass , field
5+ from typing import TypeAlias
46
57from util .recursive_eq import recursive_eq
6- from ..astgen .ast_nodes import (
7- AstNode , AstIdent , AstDeclNode , AstDefine , VarDeclType , VarDeclScope )
8+ from ..astgen .ast_nodes import *
89from ..astgen .astgen import AstGen
910from ..astgen .filtered_walker import FilteredWalker
10- from ..common import BaseLocatedError , StrRegion
11+ from ..common import BaseLocatedError , StrRegion , region_union , RegionUnionArgT
1112
1213
1314@dataclass
@@ -193,7 +194,18 @@ def err(self, msg: str, region: StrRegion):
193194 return NameResolutionError (msg , region , self .src )
194195
195196
197+ class TypecheckError (BaseLocatedError ):
198+ """Errors raised by the typechecker"""
199+
200+
201+ NodeTypecheckFn : TypeAlias = 'Callable[[Typechecker, AstNode], TypeInfo | None]'
202+
203+ _typecheck_dispatch : dict [type [AstNode ], NodeTypecheckFn ] = {}
204+
205+
196206class Typechecker :
207+ _curr_scope : Scope
208+
197209 def __init__ (self , name_resolver : NameResolver ):
198210 self .resolver = name_resolver
199211 self .src = self .resolver .src
@@ -203,16 +215,118 @@ def _init(self):
203215 self .resolver .run ()
204216 self .ast = self .resolver .ast
205217 self .top_scope = self .resolver .top_scope
218+ self ._curr_scope = self .top_scope
206219
207220 def run (self ):
208221 if self .is_ok is None :
209222 return self .is_ok
210- self ._typecheck ()
211- self .is_ok = True
223+ self ._init ()
224+ self ._typecheck (self .ast )
225+ self .is_ok = True # didn't raise any errors
212226 return self .is_ok
213227
214- def _typecheck (self ):
215- walker = FilteredWalker ()
216-
217- self .ast .walk (walker )
218- ...
228+ def _node_typechecker (self , tp = None ):
229+ if tp is None :
230+ assert callable (self )
231+ tp = self # Called as decor in this class
232+
233+ def decor (fn : NodeTypecheckFn ):
234+ _typecheck_dispatch [tp ] = fn
235+ return fn
236+ return decor
237+
238+ def _typecheck (self , n : AstNode ):
239+ try :
240+ fn = _typecheck_dispatch [type (n )]
241+ except KeyError :
242+ fn = type (self )._typecheck_node_fallback
243+ return fn (self , n )
244+
245+ def _typecheck_node_fallback (self , n : AstNode ):
246+ raise NotImplementedError (f"No typechecker function for node "
247+ f"type { type (n ).__name__ } " )
248+
249+ @_node_typechecker (AstProgramNode )
250+ def _typecheck_program (self , n : AstProgramNode ):
251+ self ._typecheck_block (n .statements )
252+
253+ def _typecheck_block (self , block : list [AstNode ]):
254+ for smt in block :
255+ if (tp := self ._typecheck (smt )) is not None :
256+ self .expect_type (tp , VoidType (), smt )
257+
258+ @_node_typechecker (AstDeclNode )
259+ def _typecheck_decl (self , n : AstDeclNode ):
260+ if not n .value : # Nothing to check
261+ return
262+ expect = self ._resolve_scope (n .scope ).declared [n .ident .id ].tp_info
263+ self .expect_type (self ._typecheck (n .value ), expect , n )
264+
265+ @_node_typechecker (AstRepeat )
266+ def _typecheck_repeat (self , n : AstRepeat ):
267+ # For now, we don't differentiate between number/string (as sc doesn't)
268+ self .expect_type (self ._typecheck (n .count ), ValType (), n .count )
269+ self ._typecheck_block (n .body )
270+
271+ @_node_typechecker (AstIf )
272+ def _typecheck_if (self , n : AstIf ):
273+ self .expect_type (self ._typecheck (n .cond ), BoolType (), n .cond )
274+ self ._typecheck_block (n .if_body )
275+ if n .else_body is not None :
276+ self ._typecheck_block (n .else_body )
277+
278+ @_node_typechecker (AstWhile )
279+ def _typecheck_while (self , n : AstWhile ):
280+ self .expect_type (self ._typecheck (n .cond ), BoolType (), n .cond )
281+ self ._typecheck_block (n .body )
282+
283+ @_node_typechecker (AstAssign )
284+ def _typecheck_assign (self , n : AstAssign ): # super tempted to call this _typecheck_ass
285+ if isinstance (n .target , AstIdent ):
286+ target_tp = self ._curr_scope .used [n .target .id ].tp_info
287+ elif isinstance (n .target , AstItem ): # ls[i] = v
288+ target_tp = self ._typecheck (n .target ) # Also checks that `ls` is a list
289+ elif isinstance (n .target , AstAttribute ):
290+ raise self .err ("Setting attributes is currently unsupported" , n .target )
291+ else :
292+ assert 0 , "Unknown simple-assignment type"
293+ if target_tp == ListType ():
294+ raise self .err ("Cannot assign directly to list" , n )
295+ self .expect_type (self ._typecheck (n .source ), target_tp , n )
296+
297+ @_node_typechecker (AstAugAssign )
298+ def _typecheck_aug_assign (self , n : AstAugAssign ):
299+ # TODO: change this when desugaring is implemented
300+ # (for now only +=, only on variables)
301+ if n .op != '+=' :
302+ raise self .err (f"The '{ n .op } ' operator is not implemented" , n )
303+ if not isinstance (n .target , AstIdent ):
304+ raise self .err (f"The '+=' operator is only implemented for variables" , n )
305+ target_tp = self ._curr_scope .used [n .target .id ].tp_info
306+ if target_tp != ValType ():
307+ raise self .err (f"Cannot apply += to { target_tp } " , n )
308+ self .expect_type (self ._typecheck (n .source ), ValType (), n .source )
309+
310+ @_node_typechecker (AstDefine )
311+ def _typecheck_define (self , n : AstDefine ):
312+ # Don't really need to check much here - type is generated from the
313+ # syntax so must be correct. Set _curr_scope and check body
314+ func_info = self ._curr_scope .declared [n .ident .id ]
315+ assert isinstance (func_info , FuncInfo )
316+ old_scope = self ._curr_scope
317+ self ._curr_scope = func_info .subscope
318+ try :
319+ self ._typecheck_block (n .body )
320+ finally :
321+ self ._curr_scope = old_scope
322+
323+ def _resolve_scope (self , scope_tp : VarDeclScope ):
324+ return self .top_scope if scope_tp == VarDeclScope .GLOBAL else self ._curr_scope
325+
326+ def err (self , msg : str , loc : RegionUnionArgT ):
327+ return TypecheckError (msg , region_union (loc ), self .src )
328+
329+ def expect_type (self , actual : TypeInfo , exp : TypeInfo , loc : RegionUnionArgT ):
330+ if exp != actual :
331+ # TODO: maybe better type formatting
332+ raise self .err (f"Expected type { exp } , got type { actual } " , loc )
0 commit comments