Skip to content

Special-case enum method calls #19634

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mypyc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MODULE_PREFIX: Final = "CPyModule_" # Cached modules
TYPE_VAR_PREFIX: Final = "CPyTypeVar_" # Type variables when using new-style Python 3.12 syntax
ATTR_PREFIX: Final = "_" # Attributes
FAST_PREFIX: Final = "__mypyc_fast_" # Optimized methods in non-extension classes

ENV_ATTR_NAME: Final = "__mypyc_env__"
NEXT_LABEL_ATTR_NAME: Final = "__mypyc_next_label__"
Expand Down
5 changes: 5 additions & 0 deletions mypyc/ir/class_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ def __init__(
# per-type free "list" of up to length 1.
self.reuse_freed_instance = False

# Is this a class inheriting from enum.Enum? Such classes can be special-cased.
self.is_enum = False

def __repr__(self) -> str:
return (
"ClassIR("
Expand Down Expand Up @@ -410,6 +413,7 @@ def serialize(self) -> JsonDict:
"init_self_leak": self.init_self_leak,
"env_user_function": self.env_user_function.id if self.env_user_function else None,
"reuse_freed_instance": self.reuse_freed_instance,
"is_enum": self.is_enum,
}

@classmethod
Expand Down Expand Up @@ -466,6 +470,7 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ClassIR:
ctx.functions[data["env_user_function"]] if data["env_user_function"] else None
)
ir.reuse_freed_instance = data["reuse_freed_instance"]
ir.is_enum = data["is_enum"]

return ir

Expand Down
17 changes: 15 additions & 2 deletions mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
Var,
)
from mypy.types import CallableType, Type, UnboundType, get_proper_type
from mypyc.common import LAMBDA_NAME, PROPSET_PREFIX, SELF_NAME
from mypyc.common import FAST_PREFIX, LAMBDA_NAME, PROPSET_PREFIX, SELF_NAME
from mypyc.ir.class_ir import ClassIR, NonExtClassInfo
from mypyc.ir.func_ir import (
FUNC_CLASSMETHOD,
Expand Down Expand Up @@ -166,6 +166,7 @@ def gen_func_item(
name: str,
sig: FuncSignature,
cdef: ClassDef | None = None,
make_ext_method: bool = False,
) -> tuple[FuncIR, Value | None]:
"""Generate and return the FuncIR for a given FuncDef.

Expand Down Expand Up @@ -217,7 +218,7 @@ def c() -> None:
class_name = None
if cdef:
ir = builder.mapper.type_to_ir[cdef.info]
in_non_ext = not ir.is_ext_class
in_non_ext = not ir.is_ext_class and not make_ext_method
class_name = cdef.name

if is_singledispatch:
Expand Down Expand Up @@ -339,6 +340,9 @@ def gen_func_ir(
fitem = fn_info.fitem
assert isinstance(fitem, FuncDef), fitem
func_decl = builder.mapper.func_to_decl[fitem]
if cdef and fn_info.name == FAST_PREFIX + func_decl.name:
# Special-cased version of a method has a separate FuncDecl, use that one.
func_decl = builder.mapper.type_to_ir[cdef.info].method_decls[fn_info.name]
if fn_info.is_decorated or is_singledispatch_main_func:
class_name = None if cdef is None else cdef.name
func_decl = FuncDecl(
Expand Down Expand Up @@ -453,6 +457,15 @@ def handle_non_ext_method(

builder.add_to_non_ext_dict(non_ext, name, func_reg, fdef.line)

# If we identified that this non-extension class method can be special-cased for
# direct access during prepare phase, generate a "static" version of it.
class_ir = builder.mapper.type_to_ir[cdef.info]
name = FAST_PREFIX + fdef.name
if name in class_ir.method_decls:
func_ir, func_reg = gen_func_item(builder, fdef, name, sig, cdef, make_ext_method=True)
class_ir.methods[name] = func_ir
builder.functions.append(func_ir)


def gen_func_ns(builder: IRBuilder) -> str:
"""Generate a namespace for a nested function using its outer function names."""
Expand Down
5 changes: 4 additions & 1 deletion mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from mypyc.common import (
BITMAP_BITS,
FAST_ISINSTANCE_MAX_SUBCLASSES,
FAST_PREFIX,
IS_FREE_THREADED,
MAX_LITERAL_SHORT_INT,
MAX_SHORT_INT,
Expand Down Expand Up @@ -1171,11 +1172,13 @@ def gen_method_call(
return self.py_method_call(base, name, arg_values, line, arg_kinds, arg_names)

# If the base type is one of ours, do a MethodCall
fast_name = FAST_PREFIX + name
if (
isinstance(base.type, RInstance)
and base.type.class_ir.is_ext_class
and (base.type.class_ir.is_ext_class or base.type.class_ir.has_method(fast_name))
and not base.type.class_ir.builtin_base
):
name = name if base.type.class_ir.is_ext_class else fast_name
if base.type.class_ir.has_method(name):
decl = base.type.class_ir.method_decl(name)
if arg_kinds is None:
Expand Down
35 changes: 34 additions & 1 deletion mypyc/irbuild/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from mypy.semanal import refers_to_fullname
from mypy.traverser import TraverserVisitor
from mypy.types import Instance, Type, get_proper_type
from mypyc.common import PROPSET_PREFIX, SELF_NAME, get_id_from_name
from mypyc.common import FAST_PREFIX, PROPSET_PREFIX, SELF_NAME, get_id_from_name
from mypyc.crash import catch_errors
from mypyc.errors import Errors
from mypyc.ir.class_ir import ClassIR
Expand Down Expand Up @@ -106,6 +106,7 @@ def build_type_map(
class_ir.children = None
mapper.type_to_ir[cdef.info] = class_ir
mapper.symbol_fullnames.add(class_ir.fullname)
class_ir.is_enum = cdef.info.is_enum and len(cdef.info.enum_members) > 0

# Populate structural information in class IR for extension classes.
for module, cdef in classes:
Expand Down Expand Up @@ -270,6 +271,36 @@ def prepare_method_def(
ir.property_types[node.name] = decl.sig.ret_type


def prepare_fast_path(
ir: ClassIR,
module_name: str,
cdef: ClassDef,
mapper: Mapper,
node: SymbolNode | None,
options: CompilerOptions,
) -> None:
"""Add fast (direct) variants of methods in non-extension classes."""
if ir.is_enum:
# We check that non-empty enums are implicitly final in mypy, so we
# can generate direct calls to enum methods.
if isinstance(node, OverloadedFuncDef):
if node.is_property:
return
node = node.impl
if not isinstance(node, FuncDef):
# TODO: support decorated methods (at least @classmethod and @staticmethod).
return
# The simplest case is a regular or overloaded method without decorators. In this
# case we can generate practically identical IR method body, but with a signature
# suitable for direct calls (usual non-extension class methods are converted to
# callable classes, and thus have an extra __mypyc_self__ argument).
name = FAST_PREFIX + node.name
sig = mapper.fdef_to_sig(node, options.strict_dunders_typing)
decl = FuncDecl(name, cdef.name, module_name, sig, FUNC_NORMAL)
ir.method_decls[name] = decl
return


def is_valid_multipart_property_def(prop: OverloadedFuncDef) -> bool:
# Checks to ensure supported property decorator semantics
if len(prop.items) != 2:
Expand Down Expand Up @@ -579,6 +610,8 @@ def prepare_non_ext_class_def(
else:
prepare_method_def(ir, module_name, cdef, mapper, get_func_def(node.node), options)

prepare_fast_path(ir, module_name, cdef, mapper, node.node, options)

if any(cls in mapper.type_to_ir and mapper.type_to_ir[cls].is_ext_class for cls in info.mro):
errors.error(
"Non-extension classes may not inherit from extension classes", path, cdef.line
Expand Down
60 changes: 60 additions & 0 deletions mypyc/test-data/irbuild-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -1408,3 +1408,63 @@ class TestOverload:

def __mypyc_generator_helper__(self, x: Any) -> Any:
return x

[case testEnumFastPath]
from enum import Enum

def test(e: E) -> bool:
return e.is_one()

class E(Enum):
ONE = 1
TWO = 2

def is_one(self) -> bool:
return self == E.ONE
[out]
def test(e):
e :: __main__.E
r0 :: bool
L0:
r0 = e.__mypyc_fast_is_one()
return r0
def is_one_E_obj.__get__(__mypyc_self__, instance, owner):
__mypyc_self__, instance, owner, r0 :: object
r1 :: bit
r2 :: object
L0:
r0 = load_address _Py_NoneStruct
r1 = instance == r0
if r1 goto L1 else goto L2 :: bool
L1:
return __mypyc_self__
L2:
r2 = PyMethod_New(__mypyc_self__, instance)
return r2
def is_one_E_obj.__call__(__mypyc_self__, self):
__mypyc_self__ :: __main__.is_one_E_obj
self, r0 :: __main__.E
r1 :: bool
r2 :: bit
L0:
r0 = __main__.E.ONE :: static
if is_error(r0) goto L1 else goto L2
L1:
r1 = raise NameError('value for final name "ONE" was not set')
unreachable
L2:
r2 = self == r0
return r2
def E.__mypyc_fast_is_one(self):
self, r0 :: __main__.E
r1 :: bool
r2 :: bit
L0:
r0 = __main__.E.ONE :: static
if is_error(r0) goto L1 else goto L2
L1:
r1 = raise NameError('value for final name "ONE" was not set')
unreachable
L2:
r2 = self == r0
return r2
50 changes: 50 additions & 0 deletions mypyc/test-data/run-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -2710,6 +2710,56 @@ from native import Player
[out]
Player.MIN = <Player.MIN: 1>

[case testEnumMethodCalls]
from enum import Enum
from typing import overload, Optional, Union

class C:
def foo(self, x: Test) -> bool:
assert Test.ONE.is_one()
assert x.next(2) == Test.THREE
assert x.prev(2) == Test.ONE
assert x.enigma(22)
assert x.enigma("22") == 22
return x.is_one(inverse=True)

class Test(Enum):
ONE = 1
TWO = 2
THREE = 3

def is_one(self, *, inverse: bool = False) -> bool:
if inverse:
return self != Test.ONE
return self == Test.ONE

@classmethod
def next(cls, val: int) -> Test:
return cls(val + 1)

@staticmethod
def prev(val: int) -> Test:
return Test(val - 1)

@overload
def enigma(self, val: int) -> bool: ...
@overload
def enigma(self, val: Optional[str] = None) -> int: ...
def enigma(self, val: Union[int, str, None] = None) -> Union[int, bool]:
if isinstance(val, int):
return self.is_one()
return 22
[file driver.py]
from native import Test, C

assert Test.ONE.is_one()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part of the test case isn't compiled -- is this intentional? Maybe move these to the compiled part of the test case under some test_<...> function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was kind of intentional, as I wanted to test both compiled and interpreted calls, but I realised I only added one call in the compiled path (see definition of C.foo). I will add more calls there (to the overloaded method one etc).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added more asserts to foo() itself (note I call C().foo() manually in the driver).

assert Test.TWO.is_one(inverse=True)
assert not C().foo(Test.ONE)
assert Test.next(2) == Test.THREE
assert Test.prev(2) == Test.ONE
assert Test.ONE.enigma(22)
assert Test.ONE.enigma("22") == 22

[case testStaticCallsWithUnpackingArgs]
from typing import Tuple

Expand Down
Loading