Skip to content
88 changes: 65 additions & 23 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@
"builtins.memoryview",
}

POISON_KEY: Final = (-1,)


class TooManyUnions(Exception):
"""Indicates that we need to stop splitting unions in an attempt
Expand Down Expand Up @@ -356,7 +358,12 @@ def __init__(

self._arg_infer_context_cache = None

self.overload_stack_depth = 0
self._args_cache: dict[tuple[int, ...], list[Type]] = {}

def reset(self) -> None:
assert self.overload_stack_depth == 0
assert not self._args_cache
self.resolved_type = {}

def visit_name_expr(self, e: NameExpr) -> Type:
Expand Down Expand Up @@ -1613,9 +1620,10 @@ def check_call(
object_type,
)
elif isinstance(callee, Overloaded):
return self.check_overload_call(
callee, args, arg_kinds, arg_names, callable_name, object_type, context
)
with self.overload_context():
return self.check_overload_call(
callee, args, arg_kinds, arg_names, callable_name, object_type, context
)
elif isinstance(callee, AnyType) or not self.chk.in_checked_function():
return self.check_any_type_call(args, callee)
elif isinstance(callee, UnionType):
Expand Down Expand Up @@ -1678,6 +1686,14 @@ def check_call(
else:
return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error)

@contextmanager
def overload_context(self) -> Iterator[None]:
self.overload_stack_depth += 1
yield
self.overload_stack_depth -= 1
if self.overload_stack_depth == 0:
self._args_cache.clear()

def check_callable_call(
self,
callee: CallableType,
Expand Down Expand Up @@ -1935,20 +1951,40 @@ def analyze_type_type_callee(self, item: ProperType, context: Context) -> Type:
self.msg.unsupported_type_type(item, context)
return AnyType(TypeOfAny.from_error)

def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]:
def infer_arg_types_in_empty_context(
self, args: list[Expression], *, allow_cache: bool
) -> list[Type]:
"""Infer argument expression types in an empty context.

In short, we basically recurse on each argument without considering
in what context the argument was called.
"""
# We can only use this hack locally while checking a single nested overloaded
# call. This saves a lot of rechecking, but is not generally safe. Cache is
# pruned upon leaving the outermost overload.
can_cache = (
allow_cache
and POISON_KEY not in self._args_cache
and not any(isinstance(t, TempNode) for t in args)
)
key = tuple(map(id, args))
if can_cache and key in self._args_cache:
return self._args_cache[key]
res: list[Type] = []

for arg in args:
arg_type = self.accept(arg)
if has_erased_component(arg_type):
res.append(NoneType())
else:
res.append(arg_type)
with self.msg.filter_errors(filter_errors=True, save_filtered_errors=True) as w:
for arg in args:
arg_type = self.accept(arg)
if has_erased_component(arg_type):
res.append(NoneType())
else:
res.append(arg_type)

if w.has_new_errors():
self.msg.add_errors(w.filtered_errors())
elif can_cache:
# Do not cache if new diagnostics were emitted: they may impact parent overload
self._args_cache[key] = res
return res

def infer_more_unions_for_recursive_type(self, type_context: Type) -> bool:
Expand Down Expand Up @@ -2712,7 +2748,7 @@ def check_overload_call(
"""Checks a call to an overloaded function."""
# Normalize unpacked kwargs before checking the call.
callee = callee.with_unpacked_kwargs()
arg_types = self.infer_arg_types_in_empty_context(args)
arg_types = self.infer_arg_types_in_empty_context(args, allow_cache=True)
# Step 1: Filter call targets to remove ones where the argument counts don't match
plausible_targets = self.plausible_overload_call_targets(
arg_types, arg_kinds, arg_names, callee
Expand Down Expand Up @@ -2921,17 +2957,16 @@ def infer_overload_return_type(

for typ in plausible_targets:
assert self.msg is self.chk.msg
with self.msg.filter_errors() as w:
with self.chk.local_type_map() as m:
ret_type, infer_type = self.check_call(
callee=typ,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
with self.msg.filter_errors() as w, self.chk.local_type_map() as m:
ret_type, infer_type = self.check_call(
callee=typ,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
is_match = not w.has_new_errors()
if is_match:
# Return early if possible; otherwise record info, so we can
Expand Down Expand Up @@ -3307,7 +3342,7 @@ def apply_generic_arguments(
)

def check_any_type_call(self, args: list[Expression], callee: Type) -> tuple[Type, Type]:
self.infer_arg_types_in_empty_context(args)
self.infer_arg_types_in_empty_context(args, allow_cache=False)
callee = get_proper_type(callee)
if isinstance(callee, AnyType):
return (
Expand Down Expand Up @@ -3478,6 +3513,7 @@ def visit_op_expr(self, e: OpExpr) -> Type:
return self.strfrm_checker.check_str_interpolation(e.left, e.right)
if isinstance(e.left, StrExpr):
return self.strfrm_checker.check_str_interpolation(e.left, e.right)

left_type = self.accept(e.left)

proper_left_type = get_proper_type(left_type)
Expand Down Expand Up @@ -4350,6 +4386,9 @@ def check_list_multiply(self, e: OpExpr) -> Type:
return result

def visit_assignment_expr(self, e: AssignmentExpr) -> Type:
if self.overload_stack_depth > 0:
# Poison cache when we encounter assignments in overloads - they affect the binder.
self._args_cache[POISON_KEY] = []
value = self.accept(e.value)
self.chk.check_assignment(e.target, e.value)
self.chk.check_final(e)
Expand Down Expand Up @@ -5405,6 +5444,9 @@ def find_typeddict_context(

def visit_lambda_expr(self, e: LambdaExpr) -> Type:
"""Type check lambda expression."""
if self.overload_stack_depth > 0:
# Poison cache when we encounter lambdas - it isn't safe to cache their types.
self._args_cache[POISON_KEY] = []
self.chk.check_default_args(e, body_is_trivial=False)
inferred_type, type_override = self.infer_lambda_type_using_context(e)
if not inferred_type:
Expand Down