Skip to content
60 changes: 46 additions & 14 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@
"builtins.memoryview",
}

POISON_KEY: Final = (-1,)


class TooManyUnions(Exception):
"""Indicates that we need to stop splitting unions in an attempt
Expand Down Expand Up @@ -356,7 +358,12 @@ def __init__(

self._arg_infer_context_cache = None

self.overload_stack_depth = 0
self._args_cache: dict[tuple[int, ...], list[Type]] = {}

def reset(self) -> None:
assert self.overload_stack_depth == 0
assert not self._args_cache
self.resolved_type = {}

def visit_name_expr(self, e: NameExpr) -> Type:
Expand Down Expand Up @@ -1613,9 +1620,10 @@ def check_call(
object_type,
)
elif isinstance(callee, Overloaded):
return self.check_overload_call(
callee, args, arg_kinds, arg_names, callable_name, object_type, context
)
with self.overload_context():
return self.check_overload_call(
callee, args, arg_kinds, arg_names, callable_name, object_type, context
)
elif isinstance(callee, AnyType) or not self.chk.in_checked_function():
return self.check_any_type_call(args, callee)
elif isinstance(callee, UnionType):
Expand Down Expand Up @@ -1674,6 +1682,14 @@ def check_call(
else:
return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error)

@contextmanager
def overload_context(self) -> Iterator[None]:
self.overload_stack_depth += 1
yield
self.overload_stack_depth -= 1
if self.overload_stack_depth == 0:
self._args_cache.clear()

def check_callable_call(
self,
callee: CallableType,
Expand Down Expand Up @@ -1937,6 +1953,17 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]
In short, we basically recurse on each argument without considering
in what context the argument was called.
"""
# We can only use this hack locally while checking a single nested overloaded
# call. This saves a lot of rechecking, but is not generally safe. Cache is
# pruned upon leaving the outermost overload.
can_cache = (
self.overload_stack_depth > 0
and POISON_KEY not in self._args_cache
and not any(isinstance(t, TempNode) for t in args)
)
key = tuple(map(id, args))
if can_cache and key in self._args_cache:
return self._args_cache[key]
res: list[Type] = []

for arg in args:
Expand All @@ -1945,6 +1972,8 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]
res.append(NoneType())
else:
res.append(arg_type)
if can_cache:
self._args_cache[key] = res
return res

def infer_more_unions_for_recursive_type(self, type_context: Type) -> bool:
Expand Down Expand Up @@ -2917,17 +2946,16 @@ def infer_overload_return_type(

for typ in plausible_targets:
assert self.msg is self.chk.msg
with self.msg.filter_errors() as w:
with self.chk.local_type_map() as m:
ret_type, infer_type = self.check_call(
callee=typ,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
with self.msg.filter_errors() as w, self.chk.local_type_map() as m:
ret_type, infer_type = self.check_call(
callee=typ,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
is_match = not w.has_new_errors()
if is_match:
# Return early if possible; otherwise record info, so we can
Expand Down Expand Up @@ -3474,6 +3502,7 @@ def visit_op_expr(self, e: OpExpr) -> Type:
return self.strfrm_checker.check_str_interpolation(e.left, e.right)
if isinstance(e.left, StrExpr):
return self.strfrm_checker.check_str_interpolation(e.left, e.right)

left_type = self.accept(e.left)

proper_left_type = get_proper_type(left_type)
Expand Down Expand Up @@ -5401,6 +5430,9 @@ def find_typeddict_context(

def visit_lambda_expr(self, e: LambdaExpr) -> Type:
"""Type check lambda expression."""
if self.overload_stack_depth > 0:
# Poison cache when we encounter lambdas - it isn't safe to cache their types.
self._args_cache[POISON_KEY] = []
self.chk.check_default_args(e, body_is_trivial=False)
inferred_type, type_override = self.infer_lambda_type_using_context(e)
if not inferred_type:
Expand Down