Skip to content

Commit 477246e

Browse files
fix(name-resolve): Fix Scope.__eq__ giving recursion errors
1 parent 69c9d27 commit 477246e

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

parser/typecheck/typecheck.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from dataclasses import dataclass, field
44

5+
from util.recursive_eq import recursive_eq
56
from ..astgen.ast_node import (
67
AstNode, walk_ast, AstIdent, AstDeclNode, AstDefine, VarDeclType,
78
VarDeclScope, FilteredWalker)
@@ -84,6 +85,9 @@ class Scope:
8485
(so type codegen/type-checker knows what each AstIdent refers to)"""
8586

8687

88+
Scope.__eq__ = recursive_eq(Scope.__eq__)
89+
90+
8791
class NameResolutionError(BaseLocatedError):
8892
pass
8993

util/recursive_eq.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Safe recursive equality comparison."""
2+
import functools
3+
4+
5+
def recursive_eq(fn):
6+
"""Must be used as decorator, like reprlib.recursive_repr.
7+
Works by hypothesising that 2 ids are equal. Then, it tries to compare
8+
them. If it encounters one of them again, it checks that the corresponding
9+
value is the hypothesised value. If so, they're equal. If not, they're
10+
unequal."""
11+
hypotheses: dict[int, int] = {} # int <-> int (should be undirected)
12+
13+
@functools.wraps(fn)
14+
def eq(a, b):
15+
if (bid_exp := hypotheses.get(id(a))) is not None:
16+
return bid_exp == id(b)
17+
if (aid_exp := hypotheses.get(id(b))) is not None:
18+
return aid_exp == id(a)
19+
hypotheses[id(a)] = id(b)
20+
hypotheses[id(b)] = id(a)
21+
try:
22+
return fn(a, b) # Will call this function again
23+
finally:
24+
del hypotheses[id(a)]
25+
del hypotheses[id(b)]
26+
return eq

0 commit comments

Comments
 (0)