diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index dbebc350bb6c..dcc5a306bcde 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -261,7 +261,7 @@ def c() -> None: ) # Re-enter the FuncItem and visit the body of the function this time. - gen_generator_func_body(builder, fn_info, sig, func_reg) + gen_generator_func_body(builder, fn_info, func_reg) else: func_ir, func_reg = gen_func_body(builder, sig, cdef, is_singledispatch) diff --git a/mypyc/irbuild/generator.py b/mypyc/irbuild/generator.py index 0e4b0e3e184a..eec27e1cfb84 100644 --- a/mypyc/irbuild/generator.py +++ b/mypyc/irbuild/generator.py @@ -13,9 +13,9 @@ from typing import Callable from mypy.nodes import ARG_OPT, FuncDef, Var -from mypyc.common import ENV_ATTR_NAME, NEXT_LABEL_ATTR_NAME, SELF_NAME +from mypyc.common import ENV_ATTR_NAME, NEXT_LABEL_ATTR_NAME from mypyc.ir.class_ir import ClassIR -from mypyc.ir.func_ir import FuncDecl, FuncIR, FuncSignature, RuntimeArg +from mypyc.ir.func_ir import FuncDecl, FuncIR from mypyc.ir.ops import ( NO_TRACEBACK_LINE_NO, BasicBlock, @@ -78,9 +78,7 @@ def gen_generator_func( return func_ir, func_reg -def gen_generator_func_body( - builder: IRBuilder, fn_info: FuncInfo, sig: FuncSignature, func_reg: Value | None -) -> None: +def gen_generator_func_body(builder: IRBuilder, fn_info: FuncInfo, func_reg: Value | None) -> None: """Generate IR based on the body of a generator function. Add "__next__", "__iter__" and other generator methods to the generator @@ -88,7 +86,7 @@ class that implements the function (each function gets a separate class). Return the symbol table for the body. """ - builder.enter(fn_info, ret_type=sig.ret_type) + builder.enter(fn_info, ret_type=object_rprimitive) setup_env_for_generator_class(builder) load_outer_envs(builder, builder.fn_info.generator_class) @@ -117,7 +115,7 @@ class that implements the function (each function gets a separate class). args, _, blocks, ret_type, fn_info = builder.leave() - add_methods_to_generator_class(builder, fn_info, sig, args, blocks, fitem.is_coroutine) + add_methods_to_generator_class(builder, fn_info, args, blocks, fitem.is_coroutine) # Evaluate argument defaults in the surrounding scope, since we # calculate them *once* when the function definition is evaluated. @@ -153,10 +151,9 @@ def instantiate_generator_class(builder: IRBuilder) -> Value: def setup_generator_class(builder: IRBuilder) -> ClassIR: - name = f"{builder.fn_info.namespaced_name()}_gen" - - generator_class_ir = ClassIR(name, builder.module_name, is_generated=True, is_final_class=True) - generator_class_ir.reuse_freed_instance = True + mapper = builder.mapper + assert isinstance(builder.fn_info.fitem, FuncDef) + generator_class_ir = mapper.fdef_to_generator[builder.fn_info.fitem] if builder.fn_info.can_merge_generator_and_env_classes(): builder.fn_info.env_class = generator_class_ir else: @@ -216,46 +213,25 @@ def add_raise_exception_blocks_to_generator_class(builder: IRBuilder, line: int) def add_methods_to_generator_class( builder: IRBuilder, fn_info: FuncInfo, - sig: FuncSignature, arg_regs: list[Register], blocks: list[BasicBlock], is_coroutine: bool, ) -> None: - helper_fn_decl = add_helper_to_generator_class(builder, arg_regs, blocks, sig, fn_info) - add_next_to_generator_class(builder, fn_info, helper_fn_decl, sig) - add_send_to_generator_class(builder, fn_info, helper_fn_decl, sig) + helper_fn_decl = add_helper_to_generator_class(builder, arg_regs, blocks, fn_info) + add_next_to_generator_class(builder, fn_info, helper_fn_decl) + add_send_to_generator_class(builder, fn_info, helper_fn_decl) add_iter_to_generator_class(builder, fn_info) - add_throw_to_generator_class(builder, fn_info, helper_fn_decl, sig) + add_throw_to_generator_class(builder, fn_info, helper_fn_decl) add_close_to_generator_class(builder, fn_info) if is_coroutine: add_await_to_generator_class(builder, fn_info) def add_helper_to_generator_class( - builder: IRBuilder, - arg_regs: list[Register], - blocks: list[BasicBlock], - sig: FuncSignature, - fn_info: FuncInfo, + builder: IRBuilder, arg_regs: list[Register], blocks: list[BasicBlock], fn_info: FuncInfo ) -> FuncDecl: """Generates a helper method for a generator class, called by '__next__' and 'throw'.""" - sig = FuncSignature( - ( - RuntimeArg(SELF_NAME, object_rprimitive), - RuntimeArg("type", object_rprimitive), - RuntimeArg("value", object_rprimitive), - RuntimeArg("traceback", object_rprimitive), - RuntimeArg("arg", object_rprimitive), - ), - sig.ret_type, - ) - helper_fn_decl = FuncDecl( - "__mypyc_generator_helper__", - fn_info.generator_class.ir.name, - builder.module_name, - sig, - internal=True, - ) + helper_fn_decl = fn_info.generator_class.ir.method_decls["__mypyc_generator_helper__"] helper_fn_ir = FuncIR( helper_fn_decl, arg_regs, blocks, fn_info.fitem.line, traceback_name=fn_info.fitem.name ) @@ -272,9 +248,7 @@ def add_iter_to_generator_class(builder: IRBuilder, fn_info: FuncInfo) -> None: builder.add(Return(builder.self())) -def add_next_to_generator_class( - builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl, sig: FuncSignature -) -> None: +def add_next_to_generator_class(builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl) -> None: """Generates the '__next__' method for a generator class.""" with builder.enter_method(fn_info.generator_class.ir, "__next__", object_rprimitive, fn_info): none_reg = builder.none_object() @@ -289,9 +263,7 @@ def add_next_to_generator_class( builder.add(Return(result)) -def add_send_to_generator_class( - builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl, sig: FuncSignature -) -> None: +def add_send_to_generator_class(builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl) -> None: """Generates the 'send' method for a generator class.""" with builder.enter_method(fn_info.generator_class.ir, "send", object_rprimitive, fn_info): arg = builder.add_argument("arg", object_rprimitive) @@ -307,9 +279,7 @@ def add_send_to_generator_class( builder.add(Return(result)) -def add_throw_to_generator_class( - builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl, sig: FuncSignature -) -> None: +def add_throw_to_generator_class(builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl) -> None: """Generates the 'throw' method for a generator class.""" with builder.enter_method(fn_info.generator_class.ir, "throw", object_rprimitive, fn_info): typ = builder.add_argument("type", object_rprimitive) diff --git a/mypyc/irbuild/main.py b/mypyc/irbuild/main.py index 7cdc6b686778..894e8f277723 100644 --- a/mypyc/irbuild/main.py +++ b/mypyc/irbuild/main.py @@ -25,7 +25,7 @@ def f(x: int) -> int: from typing import Any, Callable, TypeVar, cast from mypy.build import Graph -from mypy.nodes import ClassDef, Expression, MypyFile +from mypy.nodes import ClassDef, Expression, FuncDef, MypyFile from mypy.state import state from mypy.types import Type from mypyc.analysis.attrdefined import analyze_always_defined_attrs @@ -37,7 +37,11 @@ def f(x: int) -> int: from mypyc.irbuild.builder import IRBuilder from mypyc.irbuild.mapper import Mapper from mypyc.irbuild.prebuildvisitor import PreBuildVisitor -from mypyc.irbuild.prepare import build_type_map, find_singledispatch_register_impls +from mypyc.irbuild.prepare import ( + build_type_map, + create_generator_class_if_needed, + find_singledispatch_register_impls, +) from mypyc.irbuild.visitor import IRBuilderVisitor from mypyc.irbuild.vtable import compute_vtable from mypyc.options import CompilerOptions @@ -76,6 +80,15 @@ def build_ir( pbv = PreBuildVisitor(errors, module, singledispatch_info.decorators_to_remove, types) module.accept(pbv) + # Declare generator classes for nested async functions and generators. + for fdef in pbv.nested_funcs: + if isinstance(fdef, FuncDef): + # Make generator class name sufficiently unique. + suffix = f"___{fdef.line}" + create_generator_class_if_needed( + module.fullname, None, fdef, mapper, name_suffix=suffix + ) + # Construct and configure builder objects (cyclic runtime dependency). visitor = IRBuilderVisitor() builder = IRBuilder( diff --git a/mypyc/irbuild/mapper.py b/mypyc/irbuild/mapper.py index 7c6e03d0037c..4a01255e2d5d 100644 --- a/mypyc/irbuild/mapper.py +++ b/mypyc/irbuild/mapper.py @@ -64,6 +64,8 @@ def __init__(self, group_map: dict[str, str | None]) -> None: self.type_to_ir: dict[TypeInfo, ClassIR] = {} self.func_to_decl: dict[SymbolNode, FuncDecl] = {} self.symbol_fullnames: set[str] = set() + # The corresponding generator class that implements a generator/async function + self.fdef_to_generator: dict[FuncDef, ClassIR] = {} def type_to_rtype(self, typ: Type | None) -> RType: if typ is None: @@ -171,7 +173,14 @@ def fdef_to_sig(self, fdef: FuncDef, strict_dunders_typing: bool) -> FuncSignatu for typ, kind in zip(fdef.type.arg_types, fdef.type.arg_kinds) ] arg_pos_onlys = [name is None for name in fdef.type.arg_names] - ret = self.type_to_rtype(fdef.type.ret_type) + # TODO: We could probably support decorators sometimes (static and class method?) + if (fdef.is_coroutine or fdef.is_generator) and not fdef.is_decorated: + # Give a more precise type for generators, so that we can optimize + # code that uses them. They return a generator object, which has a + # specific class. Without this, the type would have to be 'object'. + ret: RType = RInstance(self.fdef_to_generator[fdef]) + else: + ret = self.type_to_rtype(fdef.type.ret_type) else: # Handle unannotated functions arg_types = [object_rprimitive for _ in fdef.arguments] diff --git a/mypyc/irbuild/prepare.py b/mypyc/irbuild/prepare.py index 65951999dcf9..147392585b25 100644 --- a/mypyc/irbuild/prepare.py +++ b/mypyc/irbuild/prepare.py @@ -38,7 +38,7 @@ from mypy.semanal import refers_to_fullname from mypy.traverser import TraverserVisitor from mypy.types import Instance, Type, get_proper_type -from mypyc.common import PROPSET_PREFIX, get_id_from_name +from mypyc.common import PROPSET_PREFIX, SELF_NAME, get_id_from_name from mypyc.crash import catch_errors from mypyc.errors import Errors from mypyc.ir.class_ir import ClassIR @@ -51,7 +51,14 @@ RuntimeArg, ) from mypyc.ir.ops import DeserMaps -from mypyc.ir.rtypes import RInstance, RType, dict_rprimitive, none_rprimitive, tuple_rprimitive +from mypyc.ir.rtypes import ( + RInstance, + RType, + dict_rprimitive, + none_rprimitive, + object_rprimitive, + tuple_rprimitive, +) from mypyc.irbuild.mapper import Mapper from mypyc.irbuild.util import ( get_func_def, @@ -115,7 +122,7 @@ def build_type_map( # Collect all the functions also. We collect from the symbol table # so that we can easily pick out the right copy of a function that - # is conditionally defined. + # is conditionally defined. This doesn't include nested functions! for module in modules: for func in get_module_func_defs(module): prepare_func_def(module.fullname, None, func, mapper, options) @@ -179,6 +186,8 @@ def prepare_func_def( mapper: Mapper, options: CompilerOptions, ) -> FuncDecl: + create_generator_class_if_needed(module_name, class_name, fdef, mapper) + kind = ( FUNC_STATICMETHOD if fdef.is_static @@ -190,6 +199,38 @@ def prepare_func_def( return decl +def create_generator_class_if_needed( + module_name: str, class_name: str | None, fdef: FuncDef, mapper: Mapper, name_suffix: str = "" +) -> None: + """If function is a generator/async function, declare a generator class. + + Each generator and async function gets a dedicated class that implements the + generator protocol with generated methods. + """ + if fdef.is_coroutine or fdef.is_generator: + name = "_".join(x for x in [fdef.name, class_name] if x) + "_gen" + name_suffix + cir = ClassIR(name, module_name, is_generated=True, is_final_class=True) + cir.reuse_freed_instance = True + mapper.fdef_to_generator[fdef] = cir + + helper_sig = FuncSignature( + ( + RuntimeArg(SELF_NAME, object_rprimitive), + RuntimeArg("type", object_rprimitive), + RuntimeArg("value", object_rprimitive), + RuntimeArg("traceback", object_rprimitive), + RuntimeArg("arg", object_rprimitive), + ), + object_rprimitive, + ) + + # The implementation of most generator functionality is behind this magic method. + helper_fn_decl = FuncDecl( + "__mypyc_generator_helper__", name, module_name, helper_sig, internal=True + ) + cir.method_decls[helper_fn_decl.name] = helper_fn_decl + + def prepare_method_def( ir: ClassIR, module_name: str, diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 5c32d8f1a50c..9c7ffb6a3adf 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -48,11 +48,13 @@ ) from mypyc.common import TEMP_ATTR_NAME from mypyc.ir.ops import ( + ERR_NEVER, NAMESPACE_MODULE, NO_TRACEBACK_LINE_NO, Assign, BasicBlock, Branch, + Call, InitStatic, Integer, LoadAddress, @@ -930,16 +932,41 @@ def emit_yield_from_or_await( to_yield_reg = Register(object_rprimitive) received_reg = Register(object_rprimitive) - get_op = coro_op if is_await else iter_op - if isinstance(get_op, PrimitiveDescription): - iter_val = builder.primitive_op(get_op, [val], line) + helper_method = "__mypyc_generator_helper__" + if ( + isinstance(val, (Call, MethodCall)) + and isinstance(val.type, RInstance) + and val.type.class_ir.has_method(helper_method) + ): + # This is a generated native generator class, and we can use a fast path. + # This allows two optimizations: + # 1) No need to call CPy_GetCoro() or iter() since for native generators + # it just returns the generator object (implemented here). + # 2) Instead of calling next(), call generator helper method directly, + # since next() just calls __next__ which calls the helper method. + iter_val: Value = val else: - iter_val = builder.call_c(get_op, [val], line) + get_op = coro_op if is_await else iter_op + if isinstance(get_op, PrimitiveDescription): + iter_val = builder.primitive_op(get_op, [val], line) + else: + iter_val = builder.call_c(get_op, [val], line) iter_reg = builder.maybe_spill_assignable(iter_val) stop_block, main_block, done_block = BasicBlock(), BasicBlock(), BasicBlock() - _y_init = builder.call_c(next_raw_op, [builder.read(iter_reg)], line) + + if isinstance(iter_reg.type, RInstance) and iter_reg.type.class_ir.has_method(helper_method): + # Second fast path optimization: call helper directly (see also comment above). + obj = builder.read(iter_reg) + nn = builder.none_object() + m = MethodCall(obj, helper_method, [nn, nn, nn, nn], line) + # Generators have custom error handling, so disable normal error handling. + m.error_kind = ERR_NEVER + _y_init = builder.add(m) + else: + _y_init = builder.call_c(next_raw_op, [builder.read(iter_reg)], line) + builder.add(Branch(_y_init, stop_block, main_block, Branch.IS_ERROR)) # Try extracting a return value from a StopIteration and return it. @@ -948,7 +975,7 @@ def emit_yield_from_or_await( builder.assign(result, builder.call_c(check_stop_op, [], line), line) # Clear the spilled iterator/coroutine so that it will be freed. # Otherwise, the freeing of the spilled register would likely be delayed. - err = builder.add(LoadErrorValue(object_rprimitive)) + err = builder.add(LoadErrorValue(iter_reg.type)) builder.assign(iter_reg, err, line) builder.goto(done_block) diff --git a/mypyc/test-data/run-async.test b/mypyc/test-data/run-async.test index 2dad720f99cd..b8c4c22daf71 100644 --- a/mypyc/test-data/run-async.test +++ b/mypyc/test-data/run-async.test @@ -2,6 +2,7 @@ [case testRunAsyncBasics] import asyncio +from typing import Callable, Awaitable from testutil import assertRaises @@ -72,6 +73,63 @@ def test_exception() -> None: assert asyncio.run(exc5()) == 3 assert asyncio.run(exc6()) == 3 +async def indirect_call(x: int, c: Callable[[int], Awaitable[int]]) -> int: + return await c(x) + +async def indirect_call_2(a: Awaitable[None]) -> None: + await a + +async def indirect_call_3(a: Awaitable[float]) -> float: + return (await a) + 1.0 + +async def inc(x: int) -> int: + await asyncio.sleep(0) + return x + 1 + +async def ident(x: float, err: bool = False) -> float: + await asyncio.sleep(0.0) + if err: + raise MyError() + return x + float("0.0") + +def test_indirect_call() -> None: + assert asyncio.run(indirect_call(3, inc)) == 4 + + with assertRaises(MyError): + asyncio.run(indirect_call_2(exc1())) + + assert asyncio.run(indirect_call_3(ident(2.0))) == 3.0 + assert asyncio.run(indirect_call_3(ident(-113.0))) == -112.0 + assert asyncio.run(indirect_call_3(ident(-114.0))) == -113.0 + + with assertRaises(MyError): + asyncio.run(indirect_call_3(ident(1.0, True))) + with assertRaises(MyError): + asyncio.run(indirect_call_3(ident(-113.0, True))) + +class C: + def __init__(self, n: int) -> None: + self.n = n + + async def add(self, x: int, err: bool = False) -> int: + await asyncio.sleep(0) + if err: + raise MyError() + return x + self.n + +async def method_call(x: int) -> int: + c = C(5) + return await c.add(x) + +async def method_call_exception() -> int: + c = C(5) + return await c.add(3, err=True) + +def test_async_method_call() -> None: + assert asyncio.run(method_call(3)) == 8 + with assertRaises(MyError): + asyncio.run(method_call_exception()) + [file asyncio/__init__.pyi] async def sleep(t: float) -> None: ... # eh, we could use the real type but it doesn't seem important @@ -563,8 +621,10 @@ def test_bool() -> None: def run(x: object) -> object: ... [case testRunAsyncNestedFunctions] +from __future__ import annotations + import asyncio -from typing import cast, Iterator +from typing import cast, Iterator, overload, Awaitable, Any, TypeVar from testutil import assertRaises @@ -641,6 +701,37 @@ def test_async_def_contains_two_nested_functions() -> None: (5 + 3 + 1), (7 + 4 + 10 + 1) ) +async def async_def_contains_overloaded_async_def(n: int) -> int: + @overload + async def f(x: int) -> int: ... + + @overload + async def f(x: str) -> str: ... + + async def f(x: int | str) -> Any: + return x + + return (await f(n)) + 1 + + +def test_async_def_contains_overloaded_async_def() -> None: + assert asyncio.run(async_def_contains_overloaded_async_def(5)) == 6 + +T = TypeVar("T") + +def deco(f: T) -> T: + return f + +async def async_def_contains_decorated_async_def(n: int) -> int: + @deco + async def f(x: int) -> int: + return x + 2 + + return (await f(n)) + 1 + + +def test_async_def_contains_decorated_async_def() -> None: + assert asyncio.run(async_def_contains_decorated_async_def(7)) == 10 [file asyncio/__init__.pyi] def run(x: object) -> object: ... diff --git a/mypyc/test-data/run-generators.test b/mypyc/test-data/run-generators.test index 9c0b51d58e79..a43aff27dd45 100644 --- a/mypyc/test-data/run-generators.test +++ b/mypyc/test-data/run-generators.test @@ -190,7 +190,9 @@ exit! a exception! ((1,), 'exception!') [case testYieldNested] -from typing import Callable, Generator +from typing import Callable, Generator, Iterator, TypeVar, overload + +from testutil import run_generator def normal(a: int, b: float) -> Callable: def generator(x: int, y: str) -> Generator: @@ -235,15 +237,43 @@ def outer() -> Generator: yield i return recursive(10) -[file driver.py] -from native import normal, generator, triple, another_triple, outer -from testutil import run_generator +def test_return_nested_generator() -> None: + assert run_generator(normal(1, 2.0)(3, '4.00')) == ((1, 2.0, 3, '4.00'), None) + assert run_generator(generator(1)) == ((1, 2, 3), None) + assert run_generator(triple()()) == ((1, 2, 3), None) + assert run_generator(another_triple()()) == ((1,), None) + assert run_generator(outer()) == ((0, 1, 2, 3, 4), None) + +def call_nested(x: int) -> list[int]: + def generator() -> Iterator[int]: + n = int() + 2 + yield x + yield n * x + + a = [] + for x in generator(): + a.append(x) + return a + +T = TypeVar("T") + +def deco(f: T) -> T: + return f + +def call_nested_decorated(x: int) -> list[int]: + @deco + def generator() -> Iterator[int]: + n = int() + 3 + yield x + yield n * x + + a = [] + for x in generator(): + a.append(x) + return a -assert run_generator(normal(1, 2.0)(3, '4.00')) == ((1, 2.0, 3, '4.00'), None) -assert run_generator(generator(1)) == ((1, 2, 3), None) -assert run_generator(triple()()) == ((1, 2, 3), None) -assert run_generator(another_triple()()) == ((1,), None) -assert run_generator(outer()) == ((0, 1, 2, 3, 4), None) +def test_call_nested_generator_in_function() -> None: + assert call_nested_decorated(5) == [5, 15] [case testYieldThrow] from typing import Generator, Iterable, Any, Union