Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 34 additions & 14 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,18 @@
from collections import defaultdict
from collections.abc import Iterable, Iterator, Mapping, Sequence, Set as AbstractSet
from contextlib import ExitStack, contextmanager
from typing import Callable, Final, Generic, NamedTuple, Optional, TypeVar, Union, cast, overload
from typing import (
Callable,
Final,
Generic,
Literal,
NamedTuple,
Optional,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import TypeAlias as _TypeAlias, TypeGuard

import mypy.checkexpr
Expand Down Expand Up @@ -277,6 +288,26 @@ class PartialTypeScope(NamedTuple):
is_local: bool


class LocalTypeMap:
"""Store inferred types into a temporary type map (returned).

This can be used to perform type checking "experiments" without
affecting exported types (which are used by mypyc).
"""

def __init__(self, chk: TypeChecker) -> None:
self.chk = chk

def __enter__(self) -> dict[Expression, Type]:
temp_type_map: dict[Expression, Type] = {}
self.chk._type_maps.append(temp_type_map)
return temp_type_map

def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> Literal[False]:
self.chk._type_maps.pop()
return False


class TypeChecker(NodeVisitor[None], TypeCheckerSharedApi):
"""Mypy type checker.

Expand Down Expand Up @@ -402,6 +433,7 @@ def __init__(
self.is_typeshed_stub = tree.is_typeshed_file(options)
self.inferred_attribute_types = None
self.allow_constructor_cache = True
self.local_type_map = LocalTypeMap(self)

# If True, process function definitions. If False, don't. This is used
# for processing module top levels in fine-grained incremental mode.
Expand Down Expand Up @@ -4631,7 +4663,7 @@ def check_simple_assignment(
# may cause some perf impact, plus we want to partially preserve
# the old behavior. This helps with various practical examples, see
# e.g. testOptionalTypeNarrowedByGenericCall.
with self.msg.filter_errors() as local_errors, self.local_type_map() as type_map:
with self.msg.filter_errors() as local_errors, self.local_type_map as type_map:
alt_rvalue_type = self.expr_checker.accept(
rvalue, None, always_allow_any=always_allow_any
)
Expand Down Expand Up @@ -7458,18 +7490,6 @@ def lookup_type(self, node: Expression) -> Type:
def store_types(self, d: dict[Expression, Type]) -> None:
self._type_maps[-1].update(d)

@contextmanager
def local_type_map(self) -> Iterator[dict[Expression, Type]]:
"""Store inferred types into a temporary type map (returned).

This can be used to perform type checking "experiments" without
affecting exported types (which are used by mypyc).
"""
temp_type_map: dict[Expression, Type] = {}
self._type_maps.append(temp_type_map)
yield temp_type_map
self._type_maps.pop()

def in_checked_function(self) -> bool:
"""Should we type-check the current function?

Expand Down
15 changes: 6 additions & 9 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,7 @@ def check_typeddict_call_with_kwargs(

# We don't show any errors, just infer types in a generic TypedDict type,
# a custom error message will be given below, if there are errors.
with self.msg.filter_errors(), self.chk.local_type_map():
with self.msg.filter_errors(), self.chk.local_type_map:
orig_ret_type, _ = self.check_callable_call(
infer_callee,
# We use first expression for each key to infer type variables of a generic
Expand Down Expand Up @@ -1440,7 +1440,7 @@ def is_generic_decorator_overload_call(
return None
if not isinstance(get_proper_type(callee_type.ret_type), CallableType):
return None
with self.chk.local_type_map():
with self.chk.local_type_map:
with self.msg.filter_errors():
arg_type = get_proper_type(self.accept(args[0], type_context=None))
if isinstance(arg_type, Overloaded):
Expand Down Expand Up @@ -2920,7 +2920,7 @@ def infer_overload_return_type(
for typ in plausible_targets:
assert self.msg is self.chk.msg
with self.msg.filter_errors() as w:
with self.chk.local_type_map() as m:
with self.chk.local_type_map as m:
ret_type, infer_type = self.check_call(
callee=typ,
args=args,
Expand Down Expand Up @@ -5367,7 +5367,7 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
return self.check_typeddict_literal_in_context(e, typeddict_contexts[0])
# Multiple items union, check if at least one of them matches cleanly.
for typeddict_context in typeddict_contexts:
with self.msg.filter_errors() as err, self.chk.local_type_map() as tmap:
with self.msg.filter_errors() as err, self.chk.local_type_map as tmap:
ret_type = self.check_typeddict_literal_in_context(e, typeddict_context)
if err.has_new_errors():
continue
Expand Down Expand Up @@ -6095,15 +6095,12 @@ def accept(

def accept_maybe_cache(self, node: Expression, type_context: Type | None = None) -> Type:
binder_version = self.chk.binder.version
# Micro-optimization: inline local_type_map() as it is somewhat slow in mypyc.
type_map: dict[Expression, Type] = {}
self.chk._type_maps.append(type_map)
with self.msg.filter_errors(filter_errors=True, save_filtered_errors=True) as msg:
typ = node.accept(self)
with self.chk.local_type_map as type_map:
typ = node.accept(self)
messages = msg.filtered_errors()
if binder_version == self.chk.binder.version and not self.chk.current_node_deferred:
self.expr_cache[(node, type_context)] = (binder_version, typ, messages, type_map)
self.chk._type_maps.pop()
self.chk.store_types(type_map)
self.msg.add_errors(messages)
return typ
Expand Down
4 changes: 3 additions & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3048,7 +3048,9 @@ def report_missing_module_attribute(
message = (
f'Module "{import_id}" does not explicitly export attribute "{source_id}"'
)
else:
elif not (
self.options.ignore_errors or self.cur_mod_node.path in self.errors.ignored_files
):
alternatives = set(module.names.keys()).difference({source_id})
matches = best_matches(source_id, alternatives, n=3)
if matches:
Expand Down
Loading