Skip to content

Commit 325ddcd

Browse files
committed
Micro-optimization: Make ArgKind a regular class instead of enum
Mypyc doesn't generate very efficient code for enums, so switch to a regular class. We can later revert the change if/when we can improve enum support in mypyc.

Operations related to ArgKind were pretty prominent in the op trace log (#19457).

By itself this improves performance by ~1.7%, based on `perf_compare.py`, which is significant:

```
master 4.168s (0.0%) | stdev 0.037s
HEAD   4.098s (-1.7%) | stdev 0.028s
```

This is a part of a set of micro-optimizations that improve self check performance by ~5.5%.
1 parent 02a472a commit 325ddcd

File tree

8 files changed

+123
-59
lines changed

8 files changed

+123
-59
lines changed

mypy/checkexpr.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from mypy.messages import MessageBuilder, format_type
3535
from mypy.nodes import (
3636
ARG_NAMED,
37+
ARG_NAMED_OPT,
3738
ARG_POS,
3839
ARG_STAR,
3940
ARG_STAR2,
@@ -1000,7 +1001,7 @@ def typeddict_callable_from_context(
10001001
return CallableType(
10011002
list(callee.items.values()),
10021003
[
1003-
ArgKind.ARG_NAMED if name in callee.required_keys else ArgKind.ARG_NAMED_OPT
1004+
ARG_NAMED if name in callee.required_keys else ARG_NAMED_OPT
10041005
for name in callee.items
10051006
],
10061007
list(callee.items.keys()),
@@ -1074,7 +1075,7 @@ def check_typeddict_call_with_kwargs(
10741075
# TypedDict. This is a bit arbitrary, but in most cases will work better than
10751076
# trying to infer a union or a join.
10761077
[args[0] for args in kwargs.values()],
1077-
[ArgKind.ARG_NAMED] * len(kwargs),
1078+
[ARG_NAMED] * len(kwargs),
10781079
context,
10791080
list(kwargs.keys()),
10801081
None,

mypy/nodes.py

Lines changed: 66 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,18 @@
66
from abc import abstractmethod
77
from collections import defaultdict
88
from collections.abc import Iterator, Sequence
9-
from enum import Enum, unique
10-
from typing import TYPE_CHECKING, Any, Callable, Final, Optional, TypeVar, Union, cast
9+
from typing import (
10+
TYPE_CHECKING,
11+
Any,
12+
Callable,
13+
ClassVar,
14+
Final,
15+
Optional,
16+
TypeVar,
17+
Union,
18+
cast,
19+
final,
20+
)
1121
from typing_extensions import TypeAlias as _TypeAlias, TypeGuard
1222

1323
from mypy_extensions import trait
@@ -873,7 +883,7 @@ def serialize(self) -> JsonDict:
873883
"name": self._name,
874884
"fullname": self._fullname,
875885
"arg_names": self.arg_names,
876-
"arg_kinds": [int(x.value) for x in self.arg_kinds],
886+
"arg_kinds": [x.value for x in self.arg_kinds],
877887
"type": None if self.type is None else self.type.serialize(),
878888
"flags": get_flags(self, FUNCDEF_FLAGS),
879889
"abstract_status": self.abstract_status,
@@ -904,7 +914,7 @@ def deserialize(cls, data: JsonDict) -> FuncDef:
904914
set_flags(ret, data["flags"])
905915
# NOTE: ret.info is set in the fixup phase.
906916
ret.arg_names = data["arg_names"]
907-
ret.arg_kinds = [ArgKind(x) for x in data["arg_kinds"]]
917+
ret.arg_kinds = [ArgKind.by_value(x) for x in data["arg_kinds"]]
908918
ret.abstract_status = data["abstract_status"]
909919
ret.dataclass_transform_spec = (
910920
DataclassTransformSpec.deserialize(data["dataclass_transform_spec"])
@@ -1963,21 +1973,35 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
19631973
return visitor.visit_member_expr(self)
19641974

19651975

1966-
# Kinds of arguments
1967-
@unique
1968-
class ArgKind(Enum):
1969-
# Positional argument
1970-
ARG_POS = 0
1971-
# Positional, optional argument (functions only, not calls)
1972-
ARG_OPT = 1
1973-
# *arg argument
1974-
ARG_STAR = 2
1975-
# Keyword argument x=y in call, or keyword-only function arg
1976-
ARG_NAMED = 3
1977-
# **arg argument
1978-
ARG_STAR2 = 4
1979-
# In an argument list, keyword-only and also optional
1980-
ARG_NAMED_OPT = 5
1976+
@final
1977+
class ArgKind:
1978+
"""Kinds of arguments.
1979+
1980+
NOTE: This isn't an enum due to mypyc performance limitations.
1981+
"""
1982+
1983+
_sealed: ClassVar[bool] = False # Hack to ensure enum-like behavior
1984+
1985+
def __init__(self, name: str, value: int) -> None:
1986+
assert not ArgKind._sealed
1987+
self.name: Final = name
1988+
self.value: Final = value
1989+
1990+
@staticmethod
1991+
def by_value(value: int) -> ArgKind:
1992+
if value == ARG_POS.value:
1993+
return ARG_POS
1994+
elif value == ARG_OPT.value:
1995+
return ARG_OPT
1996+
elif value == ARG_STAR.value:
1997+
return ARG_STAR
1998+
elif value == ARG_NAMED.value:
1999+
return ARG_NAMED
2000+
elif value == ARG_STAR2.value:
2001+
return ARG_STAR2
2002+
else:
2003+
assert value == ARG_NAMED_OPT.value
2004+
return ARG_NAMED_OPT
19812005

19822006
def is_positional(self, star: bool = False) -> bool:
19832007
return self == ARG_POS or self == ARG_OPT or (star and self == ARG_STAR)
@@ -1995,12 +2019,29 @@ def is_star(self) -> bool:
19952019
return self == ARG_STAR or self == ARG_STAR2
19962020

19972021

1998-
ARG_POS: Final = ArgKind.ARG_POS
1999-
ARG_OPT: Final = ArgKind.ARG_OPT
2000-
ARG_STAR: Final = ArgKind.ARG_STAR
2001-
ARG_NAMED: Final = ArgKind.ARG_NAMED
2002-
ARG_STAR2: Final = ArgKind.ARG_STAR2
2003-
ARG_NAMED_OPT: Final = ArgKind.ARG_NAMED_OPT
2022+
# Positional argument
2023+
ARG_POS: Final = ArgKind("ARG_POS", 0)
2024+
# Positional, optional argument (functions only, not calls)
2025+
ARG_OPT: Final = ArgKind("ARG_OPT", 1)
2026+
# *arg argument
2027+
ARG_STAR: Final = ArgKind("ARG_STAR", 2)
2028+
# Keyword argument x=y in call, or keyword-only function arg
2029+
ARG_NAMED: Final = ArgKind("ARG_NAMED", 3)
2030+
# **arg argument
2031+
ARG_STAR2: Final = ArgKind("ARG_STAR2", 4)
2032+
# In an argument list, keyword-only and also optional
2033+
ARG_NAMED_OPT: Final = ArgKind("ARG_NAMED_OPT", 5)
2034+
2035+
ArgKind._sealed = True # Make sure no new ArgKinds can be created
2036+
2037+
ALL_ARG_KINDS: Final[tuple[ArgKind, ...]] = (
2038+
ARG_POS,
2039+
ARG_OPT,
2040+
ARG_STAR,
2041+
ARG_NAMED,
2042+
ARG_STAR2,
2043+
ARG_NAMED_OPT,
2044+
)
20042045

20052046

20062047
class CallExpr(Expression):

mypy/plugins/functools.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@
1010
from mypy.argmap import map_actuals_to_formals
1111
from mypy.erasetype import erase_typevars
1212
from mypy.nodes import (
13+
ARG_NAMED,
14+
ARG_NAMED_OPT,
15+
ARG_OPT,
1316
ARG_POS,
17+
ARG_STAR,
1418
ARG_STAR2,
1519
SYMBOL_FUNCBASE_TYPES,
16-
ArgKind,
1720
Argument,
1821
CallExpr,
1922
NameExpr,
@@ -217,11 +220,7 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
217220
# special_sig="partial" allows omission of args/kwargs typed with ParamSpec
218221
defaulted = fn_type.copy_modified(
219222
arg_kinds=[
220-
(
221-
ArgKind.ARG_OPT
222-
if k == ArgKind.ARG_POS
223-
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
224-
)
223+
(ARG_OPT if k == ARG_POS else (ARG_NAMED_OPT if k == ARG_NAMED else k))
225224
for k in fn_type.arg_kinds
226225
],
227226
ret_type=ret_type,
@@ -284,19 +283,19 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
284283
# true when PEP 646 things are happening. See testFunctoolsPartialTypeVarTuple
285284
arg_type = fn_type.arg_types[i]
286285

287-
if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2):
286+
if not actuals or fn_type.arg_kinds[i] in (ARG_STAR, ARG_STAR2):
288287
partial_kinds.append(fn_type.arg_kinds[i])
289288
partial_types.append(arg_type)
290289
partial_names.append(fn_type.arg_names[i])
291290
else:
292291
assert actuals
293-
if any(actual_arg_kinds[j] in (ArgKind.ARG_POS, ArgKind.ARG_STAR) for j in actuals):
292+
if any(actual_arg_kinds[j] in (ARG_POS, ARG_STAR) for j in actuals):
294293
# Don't add params for arguments passed positionally
295294
continue
296295
# Add defaulted params for arguments passed via keyword
297296
kind = actual_arg_kinds[actuals[0]]
298-
if kind == ArgKind.ARG_NAMED or kind == ArgKind.ARG_STAR2:
299-
kind = ArgKind.ARG_NAMED_OPT
297+
if kind == ARG_NAMED or kind == ARG_STAR2:
298+
kind = ARG_NAMED_OPT
300299
partial_kinds.append(kind)
301300
partial_types.append(arg_type)
302301
partial_names.append(fn_type.arg_names[i])
@@ -322,9 +321,9 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
322321
if partially_applied.param_spec():
323322
assert ret.extra_attrs is not None # copy_with_extra_attr above ensures this
324323
attrs = ret.extra_attrs.copy()
325-
if ArgKind.ARG_STAR in actual_arg_kinds:
324+
if ARG_STAR in actual_arg_kinds:
326325
attrs.immutable.add("__mypy_partial_paramspec_args_bound")
327-
if ArgKind.ARG_STAR2 in actual_arg_kinds:
326+
if ARG_STAR2 in actual_arg_kinds:
328327
attrs.immutable.add("__mypy_partial_paramspec_kwargs_bound")
329328
ret.extra_attrs = attrs
330329
return ret

mypy/semanal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1041,7 +1041,7 @@ def analyze_func_def(self, defn: FuncDef) -> None:
10411041
self.pop_type_args(defn.type_args)
10421042

10431043
def remove_unpack_kwargs(self, defn: FuncDef, typ: CallableType) -> CallableType:
1044-
if not typ.arg_kinds or typ.arg_kinds[-1] is not ArgKind.ARG_STAR2:
1044+
if not typ.arg_kinds or typ.arg_kinds[-1] is not ARG_STAR2:
10451045
return typ
10461046
last_type = typ.arg_types[-1]
10471047
if not isinstance(last_type, UnpackType):

mypy/types.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import mypy.nodes
2323
from mypy.bogus_type import Bogus
2424
from mypy.nodes import (
25+
ARG_NAMED,
26+
ARG_NAMED_OPT,
2527
ARG_POS,
2628
ARG_STAR,
2729
ARG_STAR2,
@@ -1782,7 +1784,7 @@ def serialize(self) -> JsonDict:
17821784
return {
17831785
".class": "Parameters",
17841786
"arg_types": [t.serialize() for t in self.arg_types],
1785-
"arg_kinds": [int(x.value) for x in self.arg_kinds],
1787+
"arg_kinds": [x.value for x in self.arg_kinds],
17861788
"arg_names": self.arg_names,
17871789
"variables": [tv.serialize() for tv in self.variables],
17881790
"imprecise_arg_kinds": self.imprecise_arg_kinds,
@@ -1793,7 +1795,7 @@ def deserialize(cls, data: JsonDict) -> Parameters:
17931795
assert data[".class"] == "Parameters"
17941796
return Parameters(
17951797
[deserialize_type(t) for t in data["arg_types"]],
1796-
[ArgKind(x) for x in data["arg_kinds"]],
1798+
[ArgKind.by_value(x) for x in data["arg_kinds"]],
17971799
data["arg_names"],
17981800
variables=[cast(TypeVarLikeType, deserialize_type(v)) for v in data["variables"]],
17991801
imprecise_arg_kinds=data["imprecise_arg_kinds"],
@@ -2162,7 +2164,7 @@ def with_unpacked_kwargs(self) -> NormalizedCallableType:
21622164
last_type = get_proper_type(self.arg_types[-1])
21632165
assert isinstance(last_type, TypedDictType)
21642166
extra_kinds = [
2165-
ArgKind.ARG_NAMED if name in last_type.required_keys else ArgKind.ARG_NAMED_OPT
2167+
ARG_NAMED if name in last_type.required_keys else ARG_NAMED_OPT
21662168
for name in last_type.items
21672169
]
21682170
new_arg_kinds = self.arg_kinds[:-1] + extra_kinds
@@ -2283,7 +2285,7 @@ def serialize(self) -> JsonDict:
22832285
return {
22842286
".class": "CallableType",
22852287
"arg_types": [t.serialize() for t in self.arg_types],
2286-
"arg_kinds": [int(x.value) for x in self.arg_kinds],
2288+
"arg_kinds": [x.value for x in self.arg_kinds],
22872289
"arg_names": self.arg_names,
22882290
"ret_type": self.ret_type.serialize(),
22892291
"fallback": self.fallback.serialize(),
@@ -2307,7 +2309,7 @@ def deserialize(cls, data: JsonDict) -> CallableType:
23072309
# TODO: Set definition to the containing SymbolNode?
23082310
return CallableType(
23092311
[deserialize_type(t) for t in data["arg_types"]],
2310-
[ArgKind(x) for x in data["arg_kinds"]],
2312+
[ArgKind.by_value(x) for x in data["arg_kinds"]],
23112313
data["arg_names"],
23122314
deserialize_type(data["ret_type"]),
23132315
Instance.deserialize(data["fallback"]),

mypyc/codegen/emitwrapper.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,16 @@
1414

1515
from collections.abc import Sequence
1616

17-
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, ARG_OPT, ARG_POS, ARG_STAR, ARG_STAR2, ArgKind
17+
from mypy.nodes import (
18+
ALL_ARG_KINDS,
19+
ARG_NAMED,
20+
ARG_NAMED_OPT,
21+
ARG_OPT,
22+
ARG_POS,
23+
ARG_STAR,
24+
ARG_STAR2,
25+
ArgKind,
26+
)
1827
from mypy.operators import op_methods_to_symbols, reverse_op_method_names, reverse_op_methods
1928
from mypyc.codegen.emit import AssignHandler, Emitter, ErrorHandler, GotoHandler, ReturnHandler
2029
from mypyc.common import (
@@ -88,7 +97,7 @@ def generate_traceback_code(
8897

8998
def make_arg_groups(args: list[RuntimeArg]) -> dict[ArgKind, list[RuntimeArg]]:
9099
"""Group arguments by kind."""
91-
return {k: [arg for arg in args if arg.kind == k] for k in ArgKind}
100+
return {k: [arg for arg in args if arg.kind == k] for k in ALL_ARG_KINDS}
92101

93102

94103
def reorder_arg_groups(groups: dict[ArgKind, list[RuntimeArg]]) -> list[RuntimeArg]:

mypyc/ir/func_ir.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,17 @@
66
from collections.abc import Sequence
77
from typing import Final
88

9-
from mypy.nodes import ARG_POS, ArgKind, Block, FuncDef
9+
from mypy.nodes import (
10+
ARG_NAMED,
11+
ARG_NAMED_OPT,
12+
ARG_OPT,
13+
ARG_POS,
14+
ARG_STAR,
15+
ARG_STAR2,
16+
ArgKind,
17+
Block,
18+
FuncDef,
19+
)
1020
from mypyc.common import BITMAP_BITS, JsonDict, bitmap_name, get_id_from_name, short_id_from_name
1121
from mypyc.ir.ops import (
1222
Assign,
@@ -60,7 +70,7 @@ def serialize(self) -> JsonDict:
6070
return {
6171
"name": self.name,
6272
"type": self.type.serialize(),
63-
"kind": int(self.kind.value),
73+
"kind": self.kind.value,
6474
"pos_only": self.pos_only,
6575
}
6676

@@ -69,7 +79,7 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> RuntimeArg:
6979
return RuntimeArg(
7080
data["name"],
7181
deserialize_type(data["type"], ctx),
72-
ArgKind(data["kind"]),
82+
ArgKind.by_value(data["kind"]),
7383
data["pos_only"],
7484
)
7585

@@ -394,12 +404,12 @@ def all_values_full(args: list[Register], blocks: list[BasicBlock]) -> list[Valu
394404

395405

396406
_ARG_KIND_TO_INSPECT: Final = {
397-
ArgKind.ARG_POS: inspect.Parameter.POSITIONAL_OR_KEYWORD,
398-
ArgKind.ARG_OPT: inspect.Parameter.POSITIONAL_OR_KEYWORD,
399-
ArgKind.ARG_STAR: inspect.Parameter.VAR_POSITIONAL,
400-
ArgKind.ARG_NAMED: inspect.Parameter.KEYWORD_ONLY,
401-
ArgKind.ARG_STAR2: inspect.Parameter.VAR_KEYWORD,
402-
ArgKind.ARG_NAMED_OPT: inspect.Parameter.KEYWORD_ONLY,
407+
ARG_POS: inspect.Parameter.POSITIONAL_OR_KEYWORD,
408+
ARG_OPT: inspect.Parameter.POSITIONAL_OR_KEYWORD,
409+
ARG_STAR: inspect.Parameter.VAR_POSITIONAL,
410+
ARG_NAMED: inspect.Parameter.KEYWORD_ONLY,
411+
ARG_STAR2: inspect.Parameter.VAR_KEYWORD,
412+
ARG_NAMED_OPT: inspect.Parameter.KEYWORD_ONLY,
403413
}
404414

405415
# Sentinel indicating a value that cannot be represented in a text signature.
@@ -418,7 +428,7 @@ def get_text_signature(fn: FuncIR, *, bound: bool = False) -> str | None:
418428
# currently sees 'self' as being positional-or-keyword and '__x' as positional-only.
419429
pos_only_idx = -1
420430
for idx, arg in enumerate(sig.args):
421-
if arg.pos_only and arg.kind in (ArgKind.ARG_POS, ArgKind.ARG_OPT):
431+
if arg.pos_only and arg.kind in (ARG_POS, ARG_OPT):
422432
pos_only_idx = idx
423433
for idx, arg in enumerate(sig.args):
424434
if arg.name.startswith(("__bitmap", "__mypyc")):

mypyc/irbuild/function.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from typing import NamedTuple
1818

1919
from mypy.nodes import (
20+
ARG_OPT,
21+
ARG_POS,
2022
ArgKind,
2123
ClassDef,
2224
Decorator,
@@ -907,7 +909,7 @@ def generate_dispatch_glue_native_function(
907909
decl = builder.mapper.func_to_decl[fitem]
908910
arg_info = get_args(builder, decl.sig.args, line)
909911
args = [callable_class] + arg_info.args
910-
arg_kinds = [ArgKind.ARG_POS] + arg_info.arg_kinds
912+
arg_kinds = [ARG_POS] + arg_info.arg_kinds
911913
arg_names = arg_info.arg_names
912914
arg_names.insert(0, "self")
913915
ret_val = builder.builder.call(callable_class_decl, args, arg_kinds, arg_names, line)
@@ -935,7 +937,7 @@ def add_register_method_to_callable_class(builder: IRBuilder, fn_info: FuncInfo)
935937
line = -1
936938
with builder.enter_method(fn_info.callable_class.ir, "register", object_rprimitive):
937939
cls_arg = builder.add_argument("cls", object_rprimitive)
938-
func_arg = builder.add_argument("func", object_rprimitive, ArgKind.ARG_OPT)
940+
func_arg = builder.add_argument("func", object_rprimitive, ARG_OPT)
939941
ret_val = builder.call_c(register_function, [builder.self(), cls_arg, func_arg], line)
940942
builder.add(Return(ret_val, line))
941943

0 commit comments

Comments
 (0)