Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/mypy_primer.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ jobs:
--debug \
--additional-flags="--debug-serialize" \
--output concise \
--show-speed-regression \
| tee diff_${{ matrix.shard-index }}.txt
) || [ $? -eq 1 ]
- if: ${{ matrix.shard-index == 0 }}
Expand Down
4 changes: 2 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def check_first_pass(self) -> None:
Deferred functions will be processed by check_second_pass().
"""
self.recurse_into_functions = True
with state.strict_optional_set(self.options.strict_optional):
with state.strict_optional_set(self.options.strict_optional), state.type_checker_set(self):
self.errors.set_file(
self.path, self.tree.fullname, scope=self.tscope, options=self.options
)
Expand Down Expand Up @@ -496,7 +496,7 @@ def check_second_pass(
This goes through deferred nodes, returning True if there were any.
"""
self.recurse_into_functions = True
with state.strict_optional_set(self.options.strict_optional):
with state.strict_optional_set(self.options.strict_optional), state.type_checker_set(self):
if not todo and not self.deferred_nodes:
return False
self.errors.set_file(
Expand Down
55 changes: 24 additions & 31 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(
is_self: bool = False,
rvalue: Expression | None = None,
suppress_errors: bool = False,
preserve_type_var_ids: bool = False,
) -> None:
self.is_lvalue = is_lvalue
self.is_super = is_super
Expand All @@ -112,6 +113,10 @@ def __init__(
assert is_lvalue
self.rvalue = rvalue
self.suppress_errors = suppress_errors
# This attribute is only used to preserve old protocol member access logic.
# It is needed to avoid infinite recursion in cases involving self-referential
# generic methods, see find_member() for details. Do not use for other purposes!
self.preserve_type_var_ids = preserve_type_var_ids

def named_type(self, name: str) -> Instance:
return self.chk.named_type(name)
Expand Down Expand Up @@ -142,6 +147,7 @@ def copy_modified(
no_deferral=self.no_deferral,
rvalue=self.rvalue,
suppress_errors=self.suppress_errors,
preserve_type_var_ids=self.preserve_type_var_ids,
)
if self_type is not None:
mx.self_type = self_type
Expand Down Expand Up @@ -231,8 +237,6 @@ def analyze_member_access(
def _analyze_member_access(
name: str, typ: Type, mx: MemberContext, override_info: TypeInfo | None = None
) -> Type:
# TODO: This and following functions share some logic with subtypes.find_member;
# consider refactoring.
typ = get_proper_type(typ)
if isinstance(typ, Instance):
return analyze_instance_member_access(name, typ, mx, override_info)
Expand Down Expand Up @@ -355,7 +359,8 @@ def analyze_instance_member_access(
return AnyType(TypeOfAny.special_form)
assert isinstance(method.type, Overloaded)
signature = method.type
signature = freshen_all_functions_type_vars(signature)
if not mx.preserve_type_var_ids:
signature = freshen_all_functions_type_vars(signature)
if not method.is_static:
signature = check_self_arg(
signature, mx.self_type, method.is_class, mx.context, name, mx.msg
Expand Down Expand Up @@ -928,7 +933,8 @@ def analyze_var(
def expand_without_binding(
typ: Type, var: Var, itype: Instance, original_itype: Instance, mx: MemberContext
) -> Type:
typ = freshen_all_functions_type_vars(typ)
if not mx.preserve_type_var_ids:
typ = freshen_all_functions_type_vars(typ)
typ = expand_self_type_if_needed(typ, mx, var, original_itype)
expanded = expand_type_by_instance(typ, itype)
freeze_all_type_vars(expanded)
Expand All @@ -938,7 +944,8 @@ def expand_without_binding(
def expand_and_bind_callable(
functype: FunctionLike, var: Var, itype: Instance, name: str, mx: MemberContext
) -> Type:
functype = freshen_all_functions_type_vars(functype)
if not mx.preserve_type_var_ids:
functype = freshen_all_functions_type_vars(functype)
typ = get_proper_type(expand_self_type(var, functype, mx.original_type))
assert isinstance(typ, FunctionLike)
typ = check_self_arg(typ, mx.self_type, var.is_classmethod, mx.context, name, mx.msg)
Expand Down Expand Up @@ -1033,10 +1040,12 @@ def f(self: S) -> T: ...
return functype
else:
selfarg = get_proper_type(item.arg_types[0])
# This level of erasure matches the one in checker.check_func_def(),
# better keep these two checks consistent.
if subtypes.is_subtype(
# This matches similar special-casing in bind_self(), see more details there.
self_callable = name == "__call__" and isinstance(selfarg, CallableType)
if self_callable or subtypes.is_subtype(
dispatched_arg_type,
# This level of erasure matches the one in checker.check_func_def(),
# better keep these two checks consistent.
erase_typevars(erase_to_bound(selfarg)),
# This is to work around the fact that erased ParamSpec and TypeVarTuple
# callables are not always compatible with non-erased ones both ways.
Expand Down Expand Up @@ -1197,15 +1206,10 @@ def analyze_class_attribute_access(
is_classmethod = (is_decorated and cast(Decorator, node.node).func.is_class) or (
isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_class
)
is_staticmethod = (is_decorated and cast(Decorator, node.node).func.is_static) or (
isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_static
)
t = get_proper_type(t)
if isinstance(t, FunctionLike) and is_classmethod:
t = check_self_arg(t, mx.self_type, False, mx.context, name, mx.msg)
result = add_class_tvars(
t, isuper, is_classmethod, is_staticmethod, mx.self_type, original_vars=original_vars
)
result = add_class_tvars(t, isuper, is_classmethod, mx, original_vars=original_vars)
# __set__ is not called on class objects.
if not mx.is_lvalue:
result = analyze_descriptor_access(result, mx)
Expand Down Expand Up @@ -1337,8 +1341,7 @@ def add_class_tvars(
t: ProperType,
isuper: Instance | None,
is_classmethod: bool,
is_staticmethod: bool,
original_type: Type,
mx: MemberContext,
original_vars: Sequence[TypeVarLikeType] | None = None,
) -> Type:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function does not appear to be a performance bottleneck (at least in self check).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JukkaL If you will have time, could you please check if there is any slowness because of bind_self() and check_self_arg()? Although they are not modified, they may be called much more often now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check_self_arg could be more expensive -- it appears to consume an extra ~0.5% of runtime in this PR. We are now spending maybe 2-3% of CPU in it, so it's quite hot, but it already was pretty hot before this PR. This could be noise though.

I didn't see any major change in bind_self when doing self check, though it's pretty hot both before and after, though less hot than check_self_arg.

"""Instantiate type variables during analyze_class_attribute_access,
Expand All @@ -1356,9 +1359,6 @@ class B(A[str]): pass
isuper: Current instance mapped to the superclass where method was defined, this
is usually done by map_instance_to_supertype()
is_classmethod: True if this method is decorated with @classmethod
is_staticmethod: True if this method is decorated with @staticmethod
original_type: The value of the type B in the expression B.foo() or the corresponding
component in case of a union (this is used to bind the self-types)
original_vars: Type variables of the class callable on which the method was accessed
Returns:
Expanded method type with added type variables (when needed).
Expand All @@ -1379,11 +1379,11 @@ class B(A[str]): pass
# (i.e. appear in the return type of the class object on which the method was accessed).
if isinstance(t, CallableType):
tvars = original_vars if original_vars is not None else []
t = freshen_all_functions_type_vars(t)
if not mx.preserve_type_var_ids:
t = freshen_all_functions_type_vars(t)
if is_classmethod:
t = bind_self(t, original_type, is_classmethod=True)
if is_classmethod or is_staticmethod:
assert isuper is not None
t = bind_self(t, mx.self_type, is_classmethod=True)
if isuper is not None:
t = expand_type_by_instance(t, isuper)
freeze_all_type_vars(t)
return t.copy_modified(variables=list(tvars) + list(t.variables))
Expand All @@ -1392,14 +1392,7 @@ class B(A[str]): pass
[
cast(
CallableType,
add_class_tvars(
item,
isuper,
is_classmethod,
is_staticmethod,
original_type,
original_vars=original_vars,
),
add_class_tvars(item, isuper, is_classmethod, mx, original_vars=original_vars),
)
for item in t.items
]
Expand Down
3 changes: 2 additions & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Final, TypeVar, cast, overload

from mypy.nodes import ARG_STAR, FakeInfo, Var
from mypy.state import state
from mypy.types import (
ANY_STRATEGY,
AnyType,
Expand Down Expand Up @@ -544,6 +543,8 @@ def remove_trivial(types: Iterable[Type]) -> list[Type]:
* Remove everything else if there is an `object`
* Remove strict duplicate types
"""
from mypy.state import state
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This nested imports also causes a small performance regression (maybe 0.1% to 0.2%).


removed_none = False
new_types = []
all_types = set()
Expand Down
9 changes: 7 additions & 2 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2220,8 +2220,13 @@ def report_protocol_problems(
exp = get_proper_type(exp)
got = get_proper_type(got)
setter_suffix = " setter type" if is_lvalue else ""
if not isinstance(exp, (CallableType, Overloaded)) or not isinstance(
got, (CallableType, Overloaded)
if (
not isinstance(exp, (CallableType, Overloaded))
or not isinstance(got, (CallableType, Overloaded))
# If expected type is a type object, it means it is a nested class.
# Showing constructor signature in errors would be confusing in this case,
# since we don't check the signature, only subclassing of type objects.
or exp.is_type_obj()
):
self.note(
"{}: expected{} {}, got {}".format(
Expand Down
8 changes: 5 additions & 3 deletions mypy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,13 @@ class C: pass
from __future__ import annotations

from abc import abstractmethod
from typing import Any, Callable, NamedTuple, TypeVar
from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar

from mypy_extensions import mypyc_attr, trait

from mypy.errorcodes import ErrorCode
from mypy.lookup import lookup_fully_qualified
from mypy.message_registry import ErrorMessage
from mypy.messages import MessageBuilder
from mypy.nodes import (
ArgKind,
CallExpr,
Expand All @@ -138,7 +137,6 @@ class C: pass
TypeInfo,
)
from mypy.options import Options
from mypy.tvar_scope import TypeVarLikeScope
from mypy.types import (
CallableType,
FunctionLike,
Expand All @@ -149,6 +147,10 @@ class C: pass
UnboundType,
)

if TYPE_CHECKING:
from mypy.messages import MessageBuilder
from mypy.tvar_scope import TypeVarLikeScope


@trait
class TypeAnalyzerPluginInterface:
Expand Down
20 changes: 16 additions & 4 deletions mypy/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
from contextlib import contextmanager
from typing import Final

from mypy.checker_shared import TypeCheckerSharedApi

# These are global mutable state. Don't add anything here unless there's a very
# good reason.


class StrictOptionalState:
class SubtypeState:
# Wrap this in a class since it's faster that using a module-level attribute.

def __init__(self, strict_optional: bool) -> None:
# Value varies by file being processed
def __init__(self, strict_optional: bool, type_checker: TypeCheckerSharedApi | None) -> None:
# Values vary by file being processed
self.strict_optional = strict_optional
self.type_checker = type_checker

@contextmanager
def strict_optional_set(self, value: bool) -> Iterator[None]:
Expand All @@ -24,6 +27,15 @@ def strict_optional_set(self, value: bool) -> Iterator[None]:
finally:
self.strict_optional = saved

@contextmanager
def type_checker_set(self, value: TypeCheckerSharedApi) -> Iterator[None]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dependency on TypeCheckerSharedApi probably makes various import cycles worse, and I assume this why there are some additional nested imports. Defining type_checker_set in a new module would improve things, right? Splitting this module seems better than making import cycles bigger, and it should also reduce the performance regression.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Defining type_checker_set in a new module would improve things, right?

Yeah, I think this should be better. I will play with this.

saved = self.type_checker
self.type_checker = value
try:
yield
finally:
self.type_checker = saved


state: Final = StrictOptionalState(strict_optional=True)
state: Final = SubtypeState(strict_optional=True, type_checker=None)
find_occurrences: tuple[str, str] | None = None
Loading