Skip to content

Commit 044b418

Browse files
Merge pull request #68 from MarcellPerger1/add-typechecking
Fix name resolving bug, add tests
2 parents 22f9ee4 + 7f0e52c commit 044b418

File tree

8 files changed

+222
-59
lines changed

8 files changed

+222
-59
lines changed

.idea/inspectionProfiles/project_inspections.xml

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

parser/astgen/ast_node.py

Lines changed: 120 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@
22

33
from dataclasses import dataclass
44
from enum import Enum
5-
from typing import Callable, TypeAlias, Iterable
5+
from typing import Callable, TypeAlias, Iterable, TypeVar
66

7+
from util import flatten_force
78
from ..common import HasRegion, StrRegion
89

910
__all__ = [
1011
"AstNode", "AstProgramNode", "VarDeclScope", "VarDeclType", "AstDeclNode",
1112
"AstRepeat", "AstIf", "AstWhile", "AstAssign", "AstAugAssign", "AstDefine",
1213
"AstNumber", "AstString", "AstAnyName", "AstIdent", "AstAttrName",
1314
"AstListLiteral", "AstAttribute", "AstItem", "AstCall", "AstOp", "AstBinOp",
14-
"AstUnaryOp", 'walk_ast', 'WalkableT', 'WalkerFnT', 'WalkerCallType'
15+
"AstUnaryOp", 'walk_ast', 'WalkableT', 'WalkerFnT', 'WalkerCallType',
16+
"FilteredWalker"
1517
]
1618

1719

@@ -48,7 +50,6 @@ def _walk_obj_members(cls, o: WalkableT, fn: WalkerFnT):
4850
if o is None:
4951
return
5052
if isinstance(o, AstNode):
51-
# noinspection PyProtectedMember
5253
return o._walk_members(fn)
5354
try:
5455
it = iter(o)
@@ -75,6 +76,122 @@ def walk_multiple_objects(cls, fn: WalkerFnT, objs: Iterable[WalkableT]):
7576
walk_ast = AstNode.walk_obj
7677

7778

79+
# region <FilteredWalker>
80+
WT = TypeVar('WT', bound=WalkableT)
81+
VT = TypeVar('VT')
82+
SpecificCbT = Callable[[WT], bool | None]
83+
SpecificCbsDict = dict[type[WT] | type, list[Callable[[WT], bool | None]]]
84+
BothCbT = Callable[[WT, WalkerCallType], bool | None]
85+
BothCbsDict = dict[type[WT] | type, list[Callable[[WT, WalkerCallType], bool | None]]]
86+
87+
88+
class WalkerFilterRegistry:
89+
def __init__(self, enter_cbs: SpecificCbsDict = (),
90+
exit_cbs: SpecificCbsDict = (),
91+
both_sbc: BothCbsDict = ()):
92+
self.enter_cbs: SpecificCbsDict = dict(enter_cbs) # Copy them,
93+
self.exit_cbs: SpecificCbsDict = dict(exit_cbs) # also converts default () -> {}
94+
self.both_cbs: BothCbsDict = dict(both_sbc)
95+
96+
def copy(self):
97+
return WalkerFilterRegistry(self.enter_cbs, self.exit_cbs, self.both_cbs)
98+
99+
def register_both(self, t: type[WT], fn: Callable[[WT, WalkerCallType], bool | None]):
100+
self.both_cbs.setdefault(t, []).append(fn)
101+
return self
102+
103+
def register_enter(self, t: type[WT], fn: Callable[[WT], bool | None]):
104+
self.enter_cbs.setdefault(t, []).append(fn)
105+
return self
106+
107+
def register_exit(self, t: type[WT], fn: Callable[[WT], bool | None]):
108+
self.exit_cbs.setdefault(t, []).append(fn)
109+
return self
110+
111+
def __call__(self, *args, **kwargs):
112+
return self
113+
114+
def on_enter(self, *tps: type[WT] | type):
115+
"""Decorator version of register_enter."""
116+
def decor(fn: SpecificCbT):
117+
for t in tps:
118+
self.register_enter(t, fn)
119+
return fn
120+
return decor
121+
122+
def on_exit(self, *tps: type[WT] | type):
123+
"""Decorator version of register_exit."""
124+
def decor(fn: SpecificCbT):
125+
for t in tps:
126+
self.register_exit(t, fn)
127+
return fn
128+
return decor
129+
130+
def on_both(self, *tps: type[WT] | type):
131+
"""Decorator version of register_both."""
132+
def decor(fn: BothCbT):
133+
for t in tps:
134+
self.register_both(t, fn)
135+
return fn
136+
return decor
137+
138+
139+
class FilteredWalker(WalkerFilterRegistry):
140+
def __init__(self):
141+
cls_reg = self.class_registry()
142+
super().__init__(cls_reg.enter_cbs, cls_reg.exit_cbs, cls_reg.both_cbs)
143+
144+
@classmethod
145+
def class_registry(cls) -> WalkerFilterRegistry:
146+
return WalkerFilterRegistry()
147+
148+
@classmethod
149+
def create_cls_registry(cls, fn=None):
150+
"""Create a class-level registry that can be added to using decorators.
151+
152+
This can be used in two ways (at the top of your class)::
153+
154+
# MUST be this name
155+
class_registry = FilteredWalker.create_cls_registry()
156+
157+
or::
158+
159+
@classmethod
160+
@FilteredWalker.create_cls_registry
161+
def class_registry(cls): # MUST be this name
162+
pass
163+
164+
and when registering methods::
165+
166+
@class_registry.on_enter(AstDefine)
167+
def enter_define(self, ...):
168+
...
169+
170+
The restrictions on name are because we have no other way of detecting
171+
it (without metaclass dark magic) as we can't refer to the class while
172+
its namespace is being evaluated
173+
"""
174+
if fn is not None and (parent := fn(cls)) is not None:
175+
return WalkerFilterRegistry.copy(parent)
176+
return WalkerFilterRegistry()
177+
178+
def __call__(self, o: WalkableT, call_type: WalkerCallType):
179+
result = None
180+
# Call more specific ones first
181+
specific_cbs = self.enter_cbs if call_type == WalkerCallType.PRE else self.exit_cbs
182+
for fn in self._get_funcs(specific_cbs, type(o)):
183+
result = fn(o) or result
184+
for fn in self._get_funcs(self.both_cbs, type(o)):
185+
result = fn(o, call_type) or result
186+
return result
187+
188+
@classmethod
189+
def _get_funcs(cls, mapping: dict[type[WT] | type, list[VT]], tp: type[WT]) -> list[VT]:
190+
"""Also looks at superclasses/MRO"""
191+
return flatten_force(mapping.get(sub, []) for sub in tp.mro())
192+
# endregion
193+
194+
78195
@dataclass
79196
class AstProgramNode(AstNode):
80197
name = 'program'

parser/astgen/errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from parser.common import BaseParseError, BaseLocatedError
3+
from ..common import BaseParseError, BaseLocatedError
44

55

66
class AstParseError(BaseParseError):

parser/lexer/errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from parser.common import BaseParseError, BaseLocatedError
3+
from ..common import BaseParseError, BaseLocatedError
44

55

66
class TokenizerError(BaseParseError):

parser/typecheck/typecheck.py

Lines changed: 45 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,13 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass, field
4-
from typing import Callable, TypeVar
5-
6-
from parser.astgen.ast_node import (
7-
AstNode, walk_ast, WalkableT, WalkerCallType, AstIdent, AstDeclNode,
8-
AstDefine, VarDeclType, VarDeclScope)
9-
from parser.astgen.astgen import AstGen
10-
from parser.common import BaseLocatedError, StrRegion
11-
from util import flatten_force
12-
13-
WT = TypeVar('WT', bound=WalkableT)
14-
VT = TypeVar('VT')
15-
16-
17-
class FilteredWalker:
18-
def __init__(self):
19-
self.enter_cbs: dict[type[WT] | type, list[Callable[[WT], bool | None]]] = {}
20-
self.exit_cbs: dict[type[WT] | type, list[Callable[[WT], bool | None]]] = {}
21-
self.both_cbs: dict[type[WT] | type, list[
22-
Callable[[WT, WalkerCallType], bool | None]]] = {}
23-
24-
def register_both(self, t: type[WT], fn: Callable[[WT, WalkerCallType], bool | None]):
25-
self.both_cbs.setdefault(t, []).append(fn)
26-
return self
27-
28-
def register_enter(self, t: type[WT], fn: Callable[[WT], bool | None]):
29-
self.enter_cbs.setdefault(t, []).append(fn)
30-
return self
31-
32-
def register_exit(self, t: type[WT], fn: Callable[[WT], bool | None]):
33-
self.exit_cbs.setdefault(t, []).append(fn)
34-
return self
35-
36-
def __call__(self, o: WalkableT, call_type: WalkerCallType):
37-
result = None
38-
# Call more specific ones first
39-
specific_cbs = self.enter_cbs if call_type == WalkerCallType.PRE else self.exit_cbs
40-
for fn in self._get_funcs(specific_cbs, type(o)):
41-
result = fn(o) or result
42-
for fn in self._get_funcs(self.both_cbs, type(o)):
43-
result = fn(o, call_type) or result
44-
return result
454

46-
@classmethod
47-
def _get_funcs(cls, mapping: dict[type[WT] | type, list[VT]], tp: type[WT]) -> list[VT]:
48-
"""Also looks at superclasses/MRO"""
49-
return flatten_force(mapping.get(sub, []) for sub in tp.mro())
5+
from util.recursive_eq import recursive_eq
6+
from ..astgen.ast_node import (
7+
AstNode, walk_ast, AstIdent, AstDeclNode, AstDefine, VarDeclType,
8+
VarDeclScope, FilteredWalker)
9+
from ..astgen.astgen import AstGen
10+
from ..common import BaseLocatedError, StrRegion
5011

5112

5213
@dataclass
@@ -124,6 +85,9 @@ class Scope:
12485
(so type codegen/type-checker knows what each AstIdent refers to)"""
12586

12687

88+
Scope.__eq__ = recursive_eq(Scope.__eq__)
89+
90+
12791
class NameResolutionError(BaseLocatedError):
12892
pass
12993

@@ -191,14 +155,15 @@ def enter_fn_decl(fn: AstDefine):
191155
raise self.err("Function already declared", fn.ident.region)
192156
subscope = Scope()
193157
params: list[ParamInfo] = []
194-
for tp, param in fn.params:
195-
if tp.id not in PARAM_TYPES:
196-
raise self.err("Unknown parameter type", tp.region)
197-
if param.id in subscope.declared:
198-
raise self.err("There is already a parameter of this name", param.region)
199-
tp = BoolType() if param.id == 'bool' else ValType()
200-
subscope.declared[param.id] = NameInfo(subscope, param.id, tp, is_param=True)
201-
params.append(ParamInfo(param.id, tp))
158+
for tp_node, name_node in fn.params:
159+
if tp_node.id not in PARAM_TYPES:
160+
raise self.err("Unknown parameter type", tp_node.region)
161+
if (name := name_node.id) in subscope.declared:
162+
raise self.err("There is already a parameter of this name",
163+
name_node.region)
164+
tp = BoolType() if tp_node.id == 'bool' else ValType()
165+
subscope.declared[name] = NameInfo(subscope, name, tp, is_param=True)
166+
params.append(ParamInfo(name, tp))
202167
curr_scope.declared[ident] = info = FuncInfo.from_param_info(
203168
curr_scope, ident, params,
204169
ret_type=VoidType(), subscope=subscope)
@@ -226,3 +191,30 @@ def enter_fn_decl(fn: AstDefine):
226191

227192
def err(self, msg: str, region: StrRegion):
228193
return NameResolutionError(msg, region, self.src)
194+
195+
196+
class Typechecker:
197+
def __init__(self, name_resolver: NameResolver):
198+
self.resolver = name_resolver
199+
self.src = self.resolver.src
200+
self.is_ok: bool | None = None
201+
202+
def _init(self):
203+
self.resolver.run()
204+
self.ast = self.resolver.ast
205+
self.top_scope = self.resolver.top_scope
206+
207+
def run(self):
208+
if self.is_ok is None:
209+
return self.is_ok
210+
self._typecheck()
211+
self.is_ok = True
212+
return self.is_ok
213+
214+
def _typecheck(self):
215+
walker = FilteredWalker()
216+
217+
self.ast.walk(walker)
218+
...
219+
220+

test/test_typecheck/test_name_resolve.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from unittest.mock import Mock, patch
22

33
from parser.common import StrRegion
4-
from parser.typecheck.typecheck import NameResolver
4+
from parser.typecheck.typecheck import (
5+
NameResolver, Scope, NameInfo, BoolType, ValType, VoidType,
6+
FuncInfo, ParamInfo)
57
from test.common import CommonTestCase
68

79

@@ -48,6 +50,30 @@ def test_top_scope_attr(self):
4850
self.assertIs(v2, nr.top_scope)
4951
m.assert_called_once() # Still only once
5052

53+
def test_params(self):
54+
src = ('def f1(bool b0, val v0, string s0, number n0) {let L0=s0..v0;};'
55+
'def f2() {}')
56+
sc = Scope()
57+
f1_scope = Scope()
58+
f1_scope.declared = {
59+
'b0': NameInfo(f1_scope, 'b0', BoolType(), is_param=True),
60+
'v0': (v0 := NameInfo(f1_scope, 'v0', ValType(), is_param=True)),
61+
's0': (s0 := NameInfo(f1_scope, 's0', ValType(), is_param=True)),
62+
'n0': NameInfo(f1_scope, 'n0', ValType(), is_param=True),
63+
'L0': NameInfo(f1_scope, 'L0', ValType())
64+
}
65+
f1_scope.used = {'v0': v0, 's0': s0}
66+
sc.declared = {
67+
'f1': FuncInfo.from_param_info(sc, 'f1', [
68+
ParamInfo('b0', BoolType()),
69+
ParamInfo('v0', ValType()),
70+
ParamInfo('s0', ValType()), # val == string == number for now
71+
ParamInfo('n0', ValType()),
72+
], VoidType(), f1_scope),
73+
'f2': FuncInfo.from_param_info(sc, 'f2', [], VoidType(), Scope())
74+
}
75+
self.assertEqual(sc, self.getNameResolver(src).run())
76+
5177

5278
class TestNameResolveErrors(CommonTestCase):
5379
def test_undefined_var(self):

util/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from os import PathLike
44
from typing import TypeVar, Any, overload, Iterable
55

6+
from .recursive_eq import recursive_eq
67
from .simple_process_pool import *
78
from .timeouts import *
89

util/recursive_eq.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Safe recursive equality comparison."""
2+
import functools
3+
4+
5+
def recursive_eq(fn):
6+
"""Must be used as decorator, like reprlib.recursive_repr.
7+
Works by hypothesising that 2 ids are equal. Then, it tries to compare
8+
them. If it encounters one of them again, it checks that the corresponding
9+
value is the hypothesised value. If so, they're equal. If not, they're
10+
unequal."""
11+
hypotheses: dict[int, int] = {} # int <-> int (should be undirected)
12+
13+
@functools.wraps(fn)
14+
def eq(a, b):
15+
if (bid_exp := hypotheses.get(id(a))) is not None:
16+
return bid_exp == id(b)
17+
if (aid_exp := hypotheses.get(id(b))) is not None:
18+
return aid_exp == id(a)
19+
hypotheses[id(a)] = id(b)
20+
hypotheses[id(b)] = id(a)
21+
try:
22+
return fn(a, b) # Will call this function again
23+
finally:
24+
del hypotheses[id(a)]
25+
del hypotheses[id(b)]
26+
return eq

0 commit comments

Comments
 (0)