Skip to content

Commit 33996d8

Browse files
committed
[mypyc] feat: ForMap generator helper for builtins.map
1 parent 5a78607 commit 33996d8

File tree

5 files changed

+793
-0
lines changed

5 files changed

+793
-0
lines changed

mypyc/doc/native_operations.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Functions
3737
* ``slice(start, stop, step)``
3838
* ``globals()``
3939
* ``sorted(obj)``
40+
* ``map(fn, iterable)``
4041

4142
Method decorators
4243
-----------------

mypyc/irbuild/for_helpers.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from mypy.nodes import (
1313
ARG_POS,
14+
LDEF,
1415
CallExpr,
1516
DictionaryComprehension,
1617
Expression,
@@ -22,6 +23,7 @@
2223
SetExpr,
2324
TupleExpr,
2425
TypeAlias,
26+
Var,
2527
)
2628
from mypyc.ir.ops import (
2729
ERR_NEVER,
@@ -490,6 +492,16 @@ def make_for_loop_generator(
490492
for_list = ForSequence(builder, index, body_block, loop_exit, line, nested)
491493
for_list.init(expr_reg, target_type, reverse=True)
492494
return for_list
495+
496+
elif (
497+
expr.callee.fullname == "builtins.map"
498+
and len(expr.args) >= 2
499+
and all(k == ARG_POS for k in expr.arg_kinds)
500+
):
501+
for_map = ForMap(builder, index, body_block, loop_exit, line, nested)
502+
for_map.init(expr.args[0], expr.args[1:])
503+
return for_map
504+
493505
if isinstance(expr, CallExpr) and isinstance(expr.callee, MemberExpr) and not expr.args:
494506
# Special cases for dictionary iterator methods, like dict.items().
495507
rtype = builder.node_type(expr.callee.expr)
@@ -1147,3 +1159,74 @@ def gen_step(self) -> None:
11471159
def gen_cleanup(self) -> None:
11481160
for gen in self.gens:
11491161
gen.gen_cleanup()
1162+
1163+
class ForMap(ForGenerator):
1164+
"""Generate optimized IR for a for loop over map(f, ...)."""
1165+
1166+
def need_cleanup(self) -> bool:
1167+
# The wrapped for loops might need cleanup. We might generate a
1168+
# redundant cleanup block, but that's okay.
1169+
return True
1170+
1171+
def init(self, func: Expression, exprs: list[Expression]) -> None:
1172+
self.func_expr = func
1173+
self.func = self.builder.accept(func)
1174+
self.exprs = exprs
1175+
self.cond_blocks = [BasicBlock() for _ in range(len(exprs) - 1)] + [self.body_block]
1176+
1177+
self.gens: list[ForGenerator] = []
1178+
for i, iterable_expr in enumerate(exprs):
1179+
argname = f"_mypyc_map_arg_{i}"
1180+
var_type = self.builder._analyze_iterable_item_type(iterable_expr)
1181+
name_expr = NameExpr(argname)
1182+
name_expr.kind = LDEF
1183+
name_expr.node = Var(argname, var_type)
1184+
self.builder.add_local_reg(name_expr.node, self.builder.node_type(iterable_expr))
1185+
self.gens.append(
1186+
make_for_loop_generator(
1187+
self.builder,
1188+
name_expr,
1189+
iterable_expr,
1190+
#self.gens[-1].body_block if self.gens else self.body_block,
1191+
self.cond_blocks[i],
1192+
self.loop_exit,
1193+
self.line,
1194+
is_async=False,
1195+
nested=True,
1196+
)
1197+
)
1198+
1199+
def gen_condition(self) -> None:
1200+
for i, gen in enumerate(self.gens):
1201+
gen.gen_condition()
1202+
if i < len(self.gens) - 1:
1203+
self.builder.activate_block(self.cond_blocks[i])
1204+
1205+
def begin_body(self) -> None:
1206+
builder = self.builder
1207+
line = self.line
1208+
1209+
for gen in self.gens:
1210+
gen.begin_body()
1211+
1212+
# This goes here to prevent a circular import
1213+
from mypyc.irbuild.expression import transform_call_expr
1214+
1215+
call_expr = CallExpr(
1216+
self.func_expr,
1217+
#items,
1218+
[gen.index for gen in self.gens],
1219+
[ARG_POS] * len(self.gens),
1220+
[None] * len(self.gens),
1221+
)
1222+
1223+
result = transform_call_expr(builder, call_expr)
1224+
builder.assign(builder.get_assignment_target(self.index), result, line)
1225+
1226+
def gen_step(self) -> None:
1227+
for gen in self.gens:
1228+
gen.gen_step()
1229+
1230+
def gen_cleanup(self) -> None:
1231+
for gen in self.gens:
1232+
gen.gen_cleanup()

mypyc/test-data/fixtures/ir.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
overload, Mapping, Union, Callable, Sequence, FrozenSet, Protocol
88
)
99

10+
from typing_extensions import Self
11+
1012
_T = TypeVar('_T')
1113
T_co = TypeVar('T_co', covariant=True)
1214
T_contra = TypeVar('T_contra', contravariant=True)
@@ -405,3 +407,25 @@ class classmethod: pass
405407
class staticmethod: pass
406408

407409
NotImplemented: Any = ...
410+
411+
_T1 = TypeVar("_T1")
412+
_T2 = TypeVar("_T2")
413+
_T3 = TypeVar("_T3")
414+
415+
class map(Generic[_S]):
416+
@overload
417+
def __new__(cls, func: Callable[[_T1], _S], iterable: Iterable[_T1], /) -> Self: ...
418+
@overload
419+
def __new__(cls, func: Callable[[_T1, _T2], _S], iterable: Iterable[_T1], iter2: Iterable[_T2], /) -> Self: ...
420+
@overload
421+
def __new__(cls, func: Callable[[_T1, _T2, _T3], _S], iterable: Iterable[_T1], iter2: Iterable[_T2], iter3: Iterable[_T3], /) -> Self: ...
422+
def __iter__(self) -> Self: ...
423+
def __next__(self) -> _S: ...
424+
425+
class filter(Generic[_T]):
426+
@overload
427+
def __new__(cls, function: None, iterable: Iterable[_T | None], /) -> Self: ...
428+
@overload
429+
def __new__(cls, function: Callable[[_T], Any], iterable: Iterable[_T], /) -> Self: ...
430+
def __iter__(self) -> Self: ...
431+
def __next__(self) -> _T: ...

0 commit comments

Comments
 (0)