Skip to content

Commit fb9f35e

Browse files
committed
Fix recursive generator call
1 parent 988b762 commit fb9f35e

File tree

3 files changed

+28
-6
lines changed

3 files changed

+28
-6
lines changed

mypyc/irbuild/env_class.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@ def g() -> int:
1818
from __future__ import annotations
1919

2020
from mypy.nodes import Argument, FuncDef, SymbolNode, Var
21-
from mypyc.common import BITMAP_BITS, ENV_ATTR_NAME, SELF_NAME, bitmap_name
21+
from mypyc.common import (
22+
BITMAP_BITS,
23+
ENV_ATTR_NAME,
24+
GENERATOR_ATTRIBUTE_PREFIX,
25+
SELF_NAME,
26+
bitmap_name,
27+
)
2228
from mypyc.ir.class_ir import ClassIR
2329
from mypyc.ir.ops import Call, GetAttr, SetAttr, Value
2430
from mypyc.ir.rtypes import RInstance, bitmap_rprimitive, object_rprimitive
@@ -113,7 +119,7 @@ def load_env_registers(builder: IRBuilder, prefix: str = "") -> None:
113119
# If this is a FuncDef, then make sure to load the FuncDef into its own environment
114120
# class so that the function can be called recursively.
115121
if isinstance(fitem, FuncDef) and fn_info.add_nested_funcs_to_env:
116-
setup_func_for_recursive_call(builder, fitem, fn_info.callable_class)
122+
setup_func_for_recursive_call(builder, fitem, fn_info.callable_class, prefix=prefix)
117123

118124

119125
def load_outer_env(
@@ -234,12 +240,16 @@ def add_vars_to_env(builder: IRBuilder, prefix: str = "") -> None:
234240
# the same name and signature across conditional blocks
235241
# will generate different callable classes, so the callable
236242
# class that gets instantiated must be generic.
243+
if nested_fn.is_generator:
244+
prefix = GENERATOR_ATTRIBUTE_PREFIX
237245
builder.add_var_to_env_class(
238246
nested_fn, object_rprimitive, env_for_func, reassign=False, prefix=prefix
239247
)
240248

241249

242-
def setup_func_for_recursive_call(builder: IRBuilder, fdef: FuncDef, base: ImplicitClass) -> None:
250+
def setup_func_for_recursive_call(
251+
builder: IRBuilder, fdef: FuncDef, base: ImplicitClass, prefix: str = ""
252+
) -> None:
243253
"""Enable calling a nested function (with a callable class) recursively.
244254
245255
Adds the instance of the callable class representing the given
@@ -249,7 +259,8 @@ def setup_func_for_recursive_call(builder: IRBuilder, fdef: FuncDef, base: Impli
249259
"""
250260
# First, set the attribute of the environment class so that GetAttr can be called on it.
251261
prev_env = builder.fn_infos[-2].env_class
252-
prev_env.attributes[fdef.name] = builder.type_to_rtype(fdef.type)
262+
attr_name = prefix + fdef.name
263+
prev_env.attributes[attr_name] = builder.type_to_rtype(fdef.type)
253264

254265
if isinstance(base, GeneratorClass):
255266
# If we are dealing with a generator class, then we need to first get the register
@@ -261,7 +272,7 @@ def setup_func_for_recursive_call(builder: IRBuilder, fdef: FuncDef, base: Impli
261272

262273
# Obtain the instance of the callable class representing the FuncDef, and add it to the
263274
# current environment.
264-
val = builder.add(GetAttr(prev_env_reg, fdef.name, -1))
275+
val = builder.add(GetAttr(prev_env_reg, attr_name, -1))
265276
target = builder.add_local_reg(fdef, object_rprimitive)
266277
builder.assign(target, val, -1)
267278

mypyc/irbuild/generator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ class that implements the function (each function gets a separate class).
104104
and top_level
105105
and top_level.add_nested_funcs_to_env
106106
):
107-
setup_func_for_recursive_call(builder, fitem, builder.fn_info.generator_class)
107+
setup_func_for_recursive_call(
108+
builder, fitem, builder.fn_info.generator_class, prefix=GENERATOR_ATTRIBUTE_PREFIX
109+
)
108110
create_switch_for_generator_class(builder)
109111
add_raise_exception_blocks_to_generator_class(builder, fitem.line)
110112

mypyc/test-data/run-generators.test

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,17 @@ def call_nested_decorated(x: int) -> list[int]:
272272
a.append(x)
273273
return a
274274

275+
def call_nested_recursive(x: int) -> Iterator:
276+
def recursive(x: int) -> Iterator:
277+
if x > 0:
278+
yield from recursive(x - 1)
279+
yield x
280+
281+
yield from recursive(x)
282+
275283
def test_call_nested_generator_in_function() -> None:
276284
assert call_nested_decorated(5) == [5, 15]
285+
assert list(call_nested_recursive(5)) == [0, 1, 2, 3, 4, 5]
277286

278287
[case testYieldThrow]
279288
from typing import Generator, Iterable, Any, Union

0 commit comments

Comments
 (0)