diff --git a/mypyc/irbuild/main.py b/mypyc/irbuild/main.py index d2c8924a7298..f08911a1bc4c 100644 --- a/mypyc/irbuild/main.py +++ b/mypyc/irbuild/main.py @@ -38,8 +38,9 @@ def f(x: int) -> int: from mypyc.irbuild.mapper import Mapper from mypyc.irbuild.prebuildvisitor import PreBuildVisitor from mypyc.irbuild.prepare import ( + adjust_generator_classes_of_methods, build_type_map, - create_generator_class_if_needed, + create_generator_class_for_func, find_singledispatch_register_impls, ) from mypyc.irbuild.visitor import IRBuilderVisitor @@ -68,6 +69,7 @@ def build_ir( """ build_type_map(mapper, modules, graph, types, options, errors) + adjust_generator_classes_of_methods(mapper) singledispatch_info = find_singledispatch_register_impls(modules, errors) result: ModuleIRs = {} @@ -87,9 +89,10 @@ def build_ir( if isinstance(fdef, FuncDef): # Make generator class name sufficiently unique. suffix = f"___{fdef.line}" - create_generator_class_if_needed( - module.fullname, None, fdef, mapper, name_suffix=suffix - ) + if fdef.is_coroutine or fdef.is_generator: + create_generator_class_for_func( + module.fullname, None, fdef, mapper, name_suffix=suffix + ) # Construct and configure builder objects (cyclic runtime dependency). visitor = IRBuilderVisitor() diff --git a/mypyc/irbuild/mapper.py b/mypyc/irbuild/mapper.py index 05aa0e45c569..c986499b6f65 100644 --- a/mypyc/irbuild/mapper.py +++ b/mypyc/irbuild/mapper.py @@ -180,14 +180,7 @@ def fdef_to_sig(self, fdef: FuncDef, strict_dunders_typing: bool) -> FuncSignatu for typ, kind in zip(fdef.type.arg_types, fdef.type.arg_kinds) ] arg_pos_onlys = [name is None for name in fdef.type.arg_names] - # TODO: We could probably support decorators sometimes (static and class method?) - if (fdef.is_coroutine or fdef.is_generator) and not fdef.is_decorated: - # Give a more precise type for generators, so that we can optimize - # code that uses them. They return a generator object, which has a - # specific class. Without this, the type would have to be 'object'. - ret: RType = RInstance(self.fdef_to_generator[fdef]) - else: - ret = self.type_to_rtype(fdef.type.ret_type) + ret = self.type_to_rtype(fdef.type.ret_type) else: # Handle unannotated functions arg_types = [object_rprimitive for _ in fdef.arguments] diff --git a/mypyc/irbuild/prepare.py b/mypyc/irbuild/prepare.py index e4f43b38b0dc..2d0a1a8f03bf 100644 --- a/mypyc/irbuild/prepare.py +++ b/mypyc/irbuild/prepare.py @@ -196,8 +196,6 @@ def prepare_func_def( mapper: Mapper, options: CompilerOptions, ) -> FuncDecl: - create_generator_class_if_needed(module_name, class_name, fdef, mapper) - kind = ( FUNC_CLASSMETHOD if fdef.is_class @@ -209,38 +207,37 @@ def prepare_func_def( return decl -def create_generator_class_if_needed( +def create_generator_class_for_func( module_name: str, class_name: str | None, fdef: FuncDef, mapper: Mapper, name_suffix: str = "" -) -> None: - """If function is a generator/async function, declare a generator class. +) -> ClassIR: + """For a generator/async function, declare a generator class. Each generator and async function gets a dedicated class that implements the generator protocol with generated methods. """ - if fdef.is_coroutine or fdef.is_generator: - name = "_".join(x for x in [fdef.name, class_name] if x) + "_gen" + name_suffix - cir = ClassIR(name, module_name, is_generated=True, is_final_class=True) - cir.reuse_freed_instance = True - mapper.fdef_to_generator[fdef] = cir + assert fdef.is_coroutine or fdef.is_generator + name = "_".join(x for x in [fdef.name, class_name] if x) + "_gen" + name_suffix + cir = ClassIR(name, module_name, is_generated=True, is_final_class=True) + cir.reuse_freed_instance = True + mapper.fdef_to_generator[fdef] = cir - helper_sig = FuncSignature( - ( - RuntimeArg(SELF_NAME, object_rprimitive), - RuntimeArg("type", object_rprimitive), - RuntimeArg("value", object_rprimitive), - RuntimeArg("traceback", object_rprimitive), - RuntimeArg("arg", object_rprimitive), - # If non-NULL, used to store return value instead of raising StopIteration(retv) - RuntimeArg("stop_iter_ptr", object_pointer_rprimitive), - ), - object_rprimitive, - ) + helper_sig = FuncSignature( + ( + RuntimeArg(SELF_NAME, object_rprimitive), + RuntimeArg("type", object_rprimitive), + RuntimeArg("value", object_rprimitive), + RuntimeArg("traceback", object_rprimitive), + RuntimeArg("arg", object_rprimitive), + # If non-NULL, used to store return value instead of raising StopIteration(retv) + RuntimeArg("stop_iter_ptr", object_pointer_rprimitive), + ), + object_rprimitive, + ) - # The implementation of most generator functionality is behind this magic method. - helper_fn_decl = FuncDecl( - GENERATOR_HELPER_NAME, name, module_name, helper_sig, internal=True - ) - cir.method_decls[helper_fn_decl.name] = helper_fn_decl + # The implementation of most generator functionality is behind this magic method. + helper_fn_decl = FuncDecl(GENERATOR_HELPER_NAME, name, module_name, helper_sig, internal=True) + cir.method_decls[helper_fn_decl.name] = helper_fn_decl + return cir def prepare_method_def( @@ -811,3 +808,22 @@ def registered_impl_from_possible_register_call( if isinstance(node, Decorator): return RegisteredImpl(node.func, dispatch_type) return None + + +def adjust_generator_classes_of_methods(mapper: Mapper) -> None: + """Make optimizations and adjustments to generated generator classes of methods. + + This is a separate pass after type map has been built, since we need all classes + to be processed to analyze class hierarchies. + """ + for fdef, ir in mapper.func_to_decl.items(): + if isinstance(fdef, FuncDef) and (fdef.is_coroutine or fdef.is_generator): + gen_ir = create_generator_class_for_func(ir.module_name, ir.class_name, fdef, mapper) + # TODO: We could probably support decorators sometimes (static and class method?) + if not fdef.is_decorated: + # Give a more precise type for generators, so that we can optimize + # code that uses them. They return a generator object, which has a + # specific class. Without this, the type would have to be 'object'. + ir.sig.ret_type = RInstance(gen_ir) + if ir.bound_sig: + ir.bound_sig.ret_type = RInstance(gen_ir)