@@ -360,6 +360,8 @@ def __init__(
360360 self ._args_cache : dict [tuple [int , ...], list [Type ]] = {}
361361
362362 def reset (self ) -> None :
363+ assert self .overload_stack_depth == 0
364+ assert not self ._args_cache
363365 self .resolved_type = {}
364366
365367 def visit_name_expr (self , e : NameExpr ) -> Type :
@@ -1682,7 +1684,7 @@ def check_call(
16821684 def overload_context (self ) -> Iterator [None ]:
16831685 self .overload_stack_depth += 1
16841686 yield
1685- self .overload_stack_depth - = 1
1687+ self .overload_stack_depth + = 1
16861688 if self .overload_stack_depth == 0 :
16871689 self ._args_cache .clear ()
16881690
@@ -1949,7 +1951,12 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]
19491951 In short, we basically recurse on each argument without considering
19501952 in what context the argument was called.
19511953 """
1952- can_cache = not any (isinstance (t , TempNode ) for t in args )
1954+ # We can only use this hack locally while checking a single nested overloaded
1955+ # call. This saves a lot of rechecking, but is not generally safe. Cache is
1956+ # pruned upon leaving the outermost overload.
1957+ can_cache = self .overload_stack_depth > 0 and not any (
1958+ isinstance (t , TempNode ) for t in args
1959+ )
19531960 key = tuple (map (id , args ))
19541961 if can_cache and key in self ._args_cache :
19551962 return self ._args_cache [key ]
0 commit comments