Skip to content

Commit 982853d

Browse files
authored
Special-case enum method calls (#19634)
Improves mypyc/mypyc#1121, this gives a bit above 1% on mypy self-check. This only adds support for regular and overloaded methods without decorators (class/static methods and properties stay slow). When working on this I considered (and actually tried) four options: * Make enums extension classes, then many methods will use fast calls ~automatically (we will just need to set a final flag). This just didn't work, in the sense no segfaults, but it looks like we don't call `__prepare__()`, or don't call it at the right moment. Or maybe I just didn't try hard enough. In general, for some reason this feels risky. * Use existing `CPyDef`s for (non-extension) enum methods, but since they have an extra argument, `__mypyc_self__`, we can supply `NULL` there, since we know it is unused. This is actually easy and it works, but IMO it is ultra-ugly, so I decided to not do it. * Write a separate `CPyDef` without `__mypy_self__`, use it for direct calls, and make existing callable classes `CPyDef`s one-line functions that simply call the first one. This is possible, but quite complicated, and I am not sure it is easy to generalize (e.g. on classmethods). * Finally, the way I do this is to simply generate a second method, that is almost a copy of the original one. This involves a bit of code duplication (in C), but the benefit is that it is conceptually simple, and easily extendable. We can cover more special cases on as-needed basis.
1 parent 5a78607 commit 982853d

File tree

7 files changed

+169
-4
lines changed

7 files changed

+169
-4
lines changed

mypyc/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
MODULE_PREFIX: Final = "CPyModule_" # Cached modules
1616
TYPE_VAR_PREFIX: Final = "CPyTypeVar_" # Type variables when using new-style Python 3.12 syntax
1717
ATTR_PREFIX: Final = "_" # Attributes
18+
FAST_PREFIX: Final = "__mypyc_fast_" # Optimized methods in non-extension classes
1819

1920
ENV_ATTR_NAME: Final = "__mypyc_env__"
2021
NEXT_LABEL_ATTR_NAME: Final = "__mypyc_next_label__"

mypyc/ir/class_ir.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,9 @@ def __init__(
210210
# per-type free "list" of up to length 1.
211211
self.reuse_freed_instance = False
212212

213+
# Is this a class inheriting from enum.Enum? Such classes can be special-cased.
214+
self.is_enum = False
215+
213216
def __repr__(self) -> str:
214217
return (
215218
"ClassIR("
@@ -410,6 +413,7 @@ def serialize(self) -> JsonDict:
410413
"init_self_leak": self.init_self_leak,
411414
"env_user_function": self.env_user_function.id if self.env_user_function else None,
412415
"reuse_freed_instance": self.reuse_freed_instance,
416+
"is_enum": self.is_enum,
413417
}
414418

415419
@classmethod
@@ -466,6 +470,7 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ClassIR:
466470
ctx.functions[data["env_user_function"]] if data["env_user_function"] else None
467471
)
468472
ir.reuse_freed_instance = data["reuse_freed_instance"]
473+
ir.is_enum = data["is_enum"]
469474

470475
return ir
471476

mypyc/irbuild/function.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
Var,
3030
)
3131
from mypy.types import CallableType, Type, UnboundType, get_proper_type
32-
from mypyc.common import LAMBDA_NAME, PROPSET_PREFIX, SELF_NAME
32+
from mypyc.common import FAST_PREFIX, LAMBDA_NAME, PROPSET_PREFIX, SELF_NAME
3333
from mypyc.ir.class_ir import ClassIR, NonExtClassInfo
3434
from mypyc.ir.func_ir import (
3535
FUNC_CLASSMETHOD,
@@ -166,6 +166,7 @@ def gen_func_item(
166166
name: str,
167167
sig: FuncSignature,
168168
cdef: ClassDef | None = None,
169+
make_ext_method: bool = False,
169170
) -> tuple[FuncIR, Value | None]:
170171
"""Generate and return the FuncIR for a given FuncDef.
171172
@@ -217,7 +218,7 @@ def c() -> None:
217218
class_name = None
218219
if cdef:
219220
ir = builder.mapper.type_to_ir[cdef.info]
220-
in_non_ext = not ir.is_ext_class
221+
in_non_ext = not ir.is_ext_class and not make_ext_method
221222
class_name = cdef.name
222223

223224
if is_singledispatch:
@@ -339,6 +340,9 @@ def gen_func_ir(
339340
fitem = fn_info.fitem
340341
assert isinstance(fitem, FuncDef), fitem
341342
func_decl = builder.mapper.func_to_decl[fitem]
343+
if cdef and fn_info.name == FAST_PREFIX + func_decl.name:
344+
# Special-cased version of a method has a separate FuncDecl, use that one.
345+
func_decl = builder.mapper.type_to_ir[cdef.info].method_decls[fn_info.name]
342346
if fn_info.is_decorated or is_singledispatch_main_func:
343347
class_name = None if cdef is None else cdef.name
344348
func_decl = FuncDecl(
@@ -453,6 +457,15 @@ def handle_non_ext_method(
453457

454458
builder.add_to_non_ext_dict(non_ext, name, func_reg, fdef.line)
455459

460+
# If we identified that this non-extension class method can be special-cased for
461+
# direct access during prepare phase, generate a "static" version of it.
462+
class_ir = builder.mapper.type_to_ir[cdef.info]
463+
name = FAST_PREFIX + fdef.name
464+
if name in class_ir.method_decls:
465+
func_ir, func_reg = gen_func_item(builder, fdef, name, sig, cdef, make_ext_method=True)
466+
class_ir.methods[name] = func_ir
467+
builder.functions.append(func_ir)
468+
456469

457470
def gen_func_ns(builder: IRBuilder) -> str:
458471
"""Generate a namespace for a nested function using its outer function names."""

mypyc/irbuild/ll_builder.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from mypyc.common import (
1818
BITMAP_BITS,
1919
FAST_ISINSTANCE_MAX_SUBCLASSES,
20+
FAST_PREFIX,
2021
IS_FREE_THREADED,
2122
MAX_LITERAL_SHORT_INT,
2223
MAX_SHORT_INT,
@@ -1171,11 +1172,13 @@ def gen_method_call(
11711172
return self.py_method_call(base, name, arg_values, line, arg_kinds, arg_names)
11721173

11731174
# If the base type is one of ours, do a MethodCall
1175+
fast_name = FAST_PREFIX + name
11741176
if (
11751177
isinstance(base.type, RInstance)
1176-
and base.type.class_ir.is_ext_class
1178+
and (base.type.class_ir.is_ext_class or base.type.class_ir.has_method(fast_name))
11771179
and not base.type.class_ir.builtin_base
11781180
):
1181+
name = name if base.type.class_ir.is_ext_class else fast_name
11791182
if base.type.class_ir.has_method(name):
11801183
decl = base.type.class_ir.method_decl(name)
11811184
if arg_kinds is None:

mypyc/irbuild/prepare.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from mypy.semanal import refers_to_fullname
3939
from mypy.traverser import TraverserVisitor
4040
from mypy.types import Instance, Type, get_proper_type
41-
from mypyc.common import PROPSET_PREFIX, SELF_NAME, get_id_from_name
41+
from mypyc.common import FAST_PREFIX, PROPSET_PREFIX, SELF_NAME, get_id_from_name
4242
from mypyc.crash import catch_errors
4343
from mypyc.errors import Errors
4444
from mypyc.ir.class_ir import ClassIR
@@ -106,6 +106,7 @@ def build_type_map(
106106
class_ir.children = None
107107
mapper.type_to_ir[cdef.info] = class_ir
108108
mapper.symbol_fullnames.add(class_ir.fullname)
109+
class_ir.is_enum = cdef.info.is_enum and len(cdef.info.enum_members) > 0
109110

110111
# Populate structural information in class IR for extension classes.
111112
for module, cdef in classes:
@@ -270,6 +271,36 @@ def prepare_method_def(
270271
ir.property_types[node.name] = decl.sig.ret_type
271272

272273

274+
def prepare_fast_path(
275+
ir: ClassIR,
276+
module_name: str,
277+
cdef: ClassDef,
278+
mapper: Mapper,
279+
node: SymbolNode | None,
280+
options: CompilerOptions,
281+
) -> None:
282+
"""Add fast (direct) variants of methods in non-extension classes."""
283+
if ir.is_enum:
284+
# We check that non-empty enums are implicitly final in mypy, so we
285+
# can generate direct calls to enum methods.
286+
if isinstance(node, OverloadedFuncDef):
287+
if node.is_property:
288+
return
289+
node = node.impl
290+
if not isinstance(node, FuncDef):
291+
# TODO: support decorated methods (at least @classmethod and @staticmethod).
292+
return
293+
# The simplest case is a regular or overloaded method without decorators. In this
294+
# case we can generate practically identical IR method body, but with a signature
295+
# suitable for direct calls (usual non-extension class methods are converted to
296+
# callable classes, and thus have an extra __mypyc_self__ argument).
297+
name = FAST_PREFIX + node.name
298+
sig = mapper.fdef_to_sig(node, options.strict_dunders_typing)
299+
decl = FuncDecl(name, cdef.name, module_name, sig, FUNC_NORMAL)
300+
ir.method_decls[name] = decl
301+
return
302+
303+
273304
def is_valid_multipart_property_def(prop: OverloadedFuncDef) -> bool:
274305
# Checks to ensure supported property decorator semantics
275306
if len(prop.items) != 2:
@@ -579,6 +610,8 @@ def prepare_non_ext_class_def(
579610
else:
580611
prepare_method_def(ir, module_name, cdef, mapper, get_func_def(node.node), options)
581612

613+
prepare_fast_path(ir, module_name, cdef, mapper, node.node, options)
614+
582615
if any(cls in mapper.type_to_ir and mapper.type_to_ir[cls].is_ext_class for cls in info.mro):
583616
errors.error(
584617
"Non-extension classes may not inherit from extension classes", path, cdef.line

mypyc/test-data/irbuild-classes.test

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,3 +1408,63 @@ class TestOverload:
14081408

14091409
def __mypyc_generator_helper__(self, x: Any) -> Any:
14101410
return x
1411+
1412+
[case testEnumFastPath]
1413+
from enum import Enum
1414+
1415+
def test(e: E) -> bool:
1416+
return e.is_one()
1417+
1418+
class E(Enum):
1419+
ONE = 1
1420+
TWO = 2
1421+
1422+
def is_one(self) -> bool:
1423+
return self == E.ONE
1424+
[out]
1425+
def test(e):
1426+
e :: __main__.E
1427+
r0 :: bool
1428+
L0:
1429+
r0 = e.__mypyc_fast_is_one()
1430+
return r0
1431+
def is_one_E_obj.__get__(__mypyc_self__, instance, owner):
1432+
__mypyc_self__, instance, owner, r0 :: object
1433+
r1 :: bit
1434+
r2 :: object
1435+
L0:
1436+
r0 = load_address _Py_NoneStruct
1437+
r1 = instance == r0
1438+
if r1 goto L1 else goto L2 :: bool
1439+
L1:
1440+
return __mypyc_self__
1441+
L2:
1442+
r2 = PyMethod_New(__mypyc_self__, instance)
1443+
return r2
1444+
def is_one_E_obj.__call__(__mypyc_self__, self):
1445+
__mypyc_self__ :: __main__.is_one_E_obj
1446+
self, r0 :: __main__.E
1447+
r1 :: bool
1448+
r2 :: bit
1449+
L0:
1450+
r0 = __main__.E.ONE :: static
1451+
if is_error(r0) goto L1 else goto L2
1452+
L1:
1453+
r1 = raise NameError('value for final name "ONE" was not set')
1454+
unreachable
1455+
L2:
1456+
r2 = self == r0
1457+
return r2
1458+
def E.__mypyc_fast_is_one(self):
1459+
self, r0 :: __main__.E
1460+
r1 :: bool
1461+
r2 :: bit
1462+
L0:
1463+
r0 = __main__.E.ONE :: static
1464+
if is_error(r0) goto L1 else goto L2
1465+
L1:
1466+
r1 = raise NameError('value for final name "ONE" was not set')
1467+
unreachable
1468+
L2:
1469+
r2 = self == r0
1470+
return r2

mypyc/test-data/run-classes.test

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2710,6 +2710,56 @@ from native import Player
27102710
[out]
27112711
Player.MIN = <Player.MIN: 1>
27122712

2713+
[case testEnumMethodCalls]
2714+
from enum import Enum
2715+
from typing import overload, Optional, Union
2716+
2717+
class C:
2718+
def foo(self, x: Test) -> bool:
2719+
assert Test.ONE.is_one()
2720+
assert x.next(2) == Test.THREE
2721+
assert x.prev(2) == Test.ONE
2722+
assert x.enigma(22)
2723+
assert x.enigma("22") == 22
2724+
return x.is_one(inverse=True)
2725+
2726+
class Test(Enum):
2727+
ONE = 1
2728+
TWO = 2
2729+
THREE = 3
2730+
2731+
def is_one(self, *, inverse: bool = False) -> bool:
2732+
if inverse:
2733+
return self != Test.ONE
2734+
return self == Test.ONE
2735+
2736+
@classmethod
2737+
def next(cls, val: int) -> Test:
2738+
return cls(val + 1)
2739+
2740+
@staticmethod
2741+
def prev(val: int) -> Test:
2742+
return Test(val - 1)
2743+
2744+
@overload
2745+
def enigma(self, val: int) -> bool: ...
2746+
@overload
2747+
def enigma(self, val: Optional[str] = None) -> int: ...
2748+
def enigma(self, val: Union[int, str, None] = None) -> Union[int, bool]:
2749+
if isinstance(val, int):
2750+
return self.is_one()
2751+
return 22
2752+
[file driver.py]
2753+
from native import Test, C
2754+
2755+
assert Test.ONE.is_one()
2756+
assert Test.TWO.is_one(inverse=True)
2757+
assert not C().foo(Test.ONE)
2758+
assert Test.next(2) == Test.THREE
2759+
assert Test.prev(2) == Test.ONE
2760+
assert Test.ONE.enigma(22)
2761+
assert Test.ONE.enigma("22") == 22
2762+
27132763
[case testStaticCallsWithUnpackingArgs]
27142764
from typing import Tuple
27152765

0 commit comments

Comments
 (0)