Skip to content

Commit b31cc70

Browse files
committed
Assorted niche optimizations
1 parent 5051a48 commit b31cc70

File tree

4 files changed

+30
-19
lines changed

4 files changed

+30
-19
lines changed

mypy/nodes.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -915,7 +915,7 @@ def deserialize(cls, data: JsonDict) -> FuncDef:
915915
# NOTE: ret.info is set in the fixup phase.
916916
ret.arg_names = data["arg_names"]
917917
ret.original_first_arg = data.get("original_first_arg")
918-
ret.arg_kinds = [ArgKind(x) for x in data["arg_kinds"]]
918+
ret.arg_kinds = [ARG_KINDS[x] for x in data["arg_kinds"]]
919919
ret.abstract_status = data["abstract_status"]
920920
ret.dataclass_transform_spec = (
921921
DataclassTransformSpec.deserialize(data["dataclass_transform_spec"])
@@ -2013,6 +2013,8 @@ def is_star(self) -> bool:
20132013
ARG_STAR2: Final = ArgKind.ARG_STAR2
20142014
ARG_NAMED_OPT: Final = ArgKind.ARG_NAMED_OPT
20152015

2016+
ARG_KINDS: Final = (ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_STAR2, ARG_NAMED_OPT)
2017+
20162018

20172019
class CallExpr(Expression):
20182020
"""Call expression.
@@ -3488,6 +3490,8 @@ def update_tuple_type(self, typ: mypy.types.TupleType) -> None:
34883490
self.special_alias = alias
34893491
else:
34903492
self.special_alias.target = alias.target
3493+
# Invalidate recursive status cache in case it was previously set.
3494+
self.special_alias._is_recursive = None
34913495

34923496
def update_typeddict_type(self, typ: mypy.types.TypedDictType) -> None:
34933497
"""Update typeddict_type and special_alias as needed."""
@@ -3497,6 +3501,8 @@ def update_typeddict_type(self, typ: mypy.types.TypedDictType) -> None:
34973501
self.special_alias = alias
34983502
else:
34993503
self.special_alias.target = alias.target
3504+
# Invalidate recursive status cache in case it was previously set.
3505+
self.special_alias._is_recursive = None
35003506

35013507
def __str__(self) -> str:
35023508
"""Return a string representation of the type.

mypy/semanal.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5633,6 +5633,8 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
56335633
existing.node.target = res
56345634
existing.node.alias_tvars = alias_tvars
56355635
updated = True
5636+
# Invalidate recursive status cache in case it was previously set.
5637+
existing.node._is_recursive = None
56365638
else:
56375639
# Otherwise just replace existing placeholder with type alias.
56385640
existing.node = alias_node

mypy/typeops.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
flatten_nested_unions,
6464
get_proper_type,
6565
get_proper_types,
66+
remove_dups,
6667
)
6768
from mypy.typetraverser import TypeTraverserVisitor
6869
from mypy.typevars import fill_typevars
@@ -995,7 +996,7 @@ def is_singleton_type(typ: Type) -> bool:
995996
return typ.is_singleton_type()
996997

997998

998-
def try_expanding_sum_type_to_union(typ: Type, target_fullname: str) -> ProperType:
999+
def try_expanding_sum_type_to_union(typ: Type, target_fullname: str) -> Type:
9991000
"""Attempts to recursively expand any enum Instances with the given target_fullname
10001001
into a Union of all of its component LiteralTypes.
10011002
@@ -1017,21 +1018,22 @@ class Status(Enum):
10171018
typ = get_proper_type(typ)
10181019

10191020
if isinstance(typ, UnionType):
1021+
# Non-empty enums cannot subclass each other so simply removing duplicates is enough.
10201022
items = [
1021-
try_expanding_sum_type_to_union(item, target_fullname) for item in typ.relevant_items()
1023+
try_expanding_sum_type_to_union(item, target_fullname)
1024+
for item in remove_dups(flatten_nested_unions(typ.relevant_items()))
10221025
]
1023-
return make_simplified_union(items, contract_literals=False)
1026+
return UnionType.make_union(items)
10241027

10251028
if isinstance(typ, Instance) and typ.type.fullname == target_fullname:
10261029
if typ.type.fullname == "builtins.bool":
1027-
items = [LiteralType(True, typ), LiteralType(False, typ)]
1028-
return make_simplified_union(items, contract_literals=False)
1030+
return UnionType([LiteralType(True, typ), LiteralType(False, typ)])
10291031

10301032
if typ.type.is_enum:
10311033
items = [LiteralType(name, typ) for name in typ.type.enum_members]
10321034
if not items:
10331035
return typ
1034-
return make_simplified_union(items, contract_literals=False)
1036+
return UnionType.make_union(items)
10351037

10361038
return typ
10371039

mypy/types.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import mypy.nodes
2323
from mypy.bogus_type import Bogus
24-
from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, INVARIANT, ArgKind, FakeInfo, SymbolNode
24+
from mypy.nodes import ARG_KINDS, ARG_POS, ARG_STAR, ARG_STAR2, INVARIANT, ArgKind, SymbolNode
2525
from mypy.options import Options
2626
from mypy.state import state
2727
from mypy.util import IdMapper
@@ -538,6 +538,10 @@ def __repr__(self) -> str:
538538
return self.raw_id.__repr__()
539539

540540
def __eq__(self, other: object) -> bool:
541+
# Although this call is not expensive (like UnionType or TypedDictType),
542+
# most of the time we get the same object here, so add a fast path.
543+
if self is other:
544+
return True
541545
return (
542546
isinstance(other, TypeVarId)
543547
and self.raw_id == other.raw_id
@@ -1780,7 +1784,9 @@ def deserialize(cls, data: JsonDict) -> Parameters:
17801784
assert data[".class"] == "Parameters"
17811785
return Parameters(
17821786
[deserialize_type(t) for t in data["arg_types"]],
1783-
[ArgKind(x) for x in data["arg_kinds"]],
1787+
# This is a micro-optimization until mypyc gets dedicated enum support. Otherwise,
1788+
# we would spend ~20% of types deserialization time in Enum.__call__().
1789+
[ARG_KINDS[x] for x in data["arg_kinds"]],
17841790
data["arg_names"],
17851791
variables=[cast(TypeVarLikeType, deserialize_type(v)) for v in data["variables"]],
17861792
imprecise_arg_kinds=data["imprecise_arg_kinds"],
@@ -1797,7 +1803,7 @@ def __hash__(self) -> int:
17971803
)
17981804

17991805
def __eq__(self, other: object) -> bool:
1800-
if isinstance(other, (Parameters, CallableType)):
1806+
if isinstance(other, Parameters):
18011807
return (
18021808
self.arg_types == other.arg_types
18031809
and self.arg_names == other.arg_names
@@ -2210,15 +2216,9 @@ def with_normalized_var_args(self) -> Self:
22102216
)
22112217

22122218
def __hash__(self) -> int:
2213-
# self.is_type_obj() will fail if self.fallback.type is a FakeInfo
2214-
if isinstance(self.fallback.type, FakeInfo):
2215-
is_type_obj = 2
2216-
else:
2217-
is_type_obj = self.is_type_obj()
22182219
return hash(
22192220
(
22202221
self.ret_type,
2221-
is_type_obj,
22222222
self.is_ellipsis_args,
22232223
self.name,
22242224
tuple(self.arg_types),
@@ -2236,7 +2236,6 @@ def __eq__(self, other: object) -> bool:
22362236
and self.arg_names == other.arg_names
22372237
and self.arg_kinds == other.arg_kinds
22382238
and self.name == other.name
2239-
and self.is_type_obj() == other.is_type_obj()
22402239
and self.is_ellipsis_args == other.is_ellipsis_args
22412240
and self.type_guard == other.type_guard
22422241
and self.type_is == other.type_is
@@ -2271,10 +2270,10 @@ def serialize(self) -> JsonDict:
22712270
@classmethod
22722271
def deserialize(cls, data: JsonDict) -> CallableType:
22732272
assert data[".class"] == "CallableType"
2274-
# TODO: Set definition to the containing SymbolNode?
2273+
# The .definition link is set in fixup.py.
22752274
return CallableType(
22762275
[deserialize_type(t) for t in data["arg_types"]],
2277-
[ArgKind(x) for x in data["arg_kinds"]],
2276+
[ARG_KINDS[x] for x in data["arg_kinds"]],
22782277
data["arg_names"],
22792278
deserialize_type(data["ret_type"]),
22802279
Instance.deserialize(data["fallback"]),
@@ -2931,6 +2930,8 @@ def __hash__(self) -> int:
29312930
def __eq__(self, other: object) -> bool:
29322931
if not isinstance(other, UnionType):
29332932
return NotImplemented
2933+
if self is other:
2934+
return True
29342935
return frozenset(self.items) == frozenset(other.items)
29352936

29362937
@overload

0 commit comments

Comments
 (0)