Skip to content

Commit b3f3eab

Browse files
committed
[mypyc] feat: ForFilter generator helper for builtins.filter
1 parent 5a78607 commit b3f3eab

File tree

3 files changed

+152
-0
lines changed

3 files changed

+152
-0
lines changed

mypyc/irbuild/for_helpers.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,16 @@ def make_for_loop_generator(
490490
for_list = ForSequence(builder, index, body_block, loop_exit, line, nested)
491491
for_list.init(expr_reg, target_type, reverse=True)
492492
return for_list
493+
494+
elif (
495+
expr.callee.fullname == "builtins.filter"
496+
and len(expr.args) == 2
497+
and all(k == ARG_POS for k in expr.arg_kinds)
498+
):
499+
for_filter = ForFilter(builder, index, body_block, loop_exit, line, nested)
500+
for_filter.init(index, expr.args[0], expr.args[1])
501+
return for_filter
502+
493503
if isinstance(expr, CallExpr) and isinstance(expr.callee, MemberExpr) and not expr.args:
494504
# Special cases for dictionary iterator methods, like dict.items().
495505
rtype = builder.node_type(expr.callee.expr)
@@ -1147,3 +1157,45 @@ def gen_step(self) -> None:
11471157
def gen_cleanup(self) -> None:
11481158
for gen in self.gens:
11491159
gen.gen_cleanup()
1160+
1161+
1162+
class ForFilter(ForGenerator):
1163+
"""Generate optimized IR for a for loop over filter(f, iterable)."""
1164+
1165+
def need_cleanup(self) -> bool:
1166+
# The wrapped for loops might need cleanup. We might generate a
1167+
# redundant cleanup block, but that's okay.
1168+
return True
1169+
1170+
def init(self, index: Lvalue, func: Expression, iterable: Expression) -> None:
1171+
self.filter_func = self.builder.accept(func)
1172+
self.iterable = iterable
1173+
self.index = index
1174+
1175+
self.gen = make_for_loop_generator(
1176+
self.builder, self.index, self.iterable, self.body_block, self.loop_exit, self.line, is_async=False, nested=True
1177+
)
1178+
1179+
def gen_condition(self) -> None:
1180+
builder = self.builder
1181+
line = self.line
1182+
# First, get the next item from the sub-generator
1183+
self.gen.gen_condition()
1184+
# Now, filter: only enter the body if func(item) is truthy
1185+
filter_block = BasicBlock()
1186+
builder.activate_block(filter_block)
1187+
self.gen.begin_body()
1188+
item = builder.read(builder.get_assignment_target(self.index), line)
1189+
# TODO: implement logic to handle c calls of native functions
1190+
result = builder.py_call(self.filter_func, [item], line)
1191+
builder.add_bool_branch(result, self.body_block, self.loop_exit)
1192+
1193+
def begin_body(self) -> None:
1194+
# The item is already assigned to self.index by the sub-generator.
1195+
pass
1196+
1197+
def gen_step(self) -> None:
1198+
self.gen.gen_step()
1199+
1200+
def gen_cleanup(self) -> None:
1201+
self.gen.gen_cleanup()

mypyc/test-data/irbuild-basic.test

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3546,3 +3546,73 @@ L0:
35463546
r2 = PyObject_Vectorcall(r1, 0, 0, 0)
35473547
r3 = box(None, 1)
35483548
return r3
3549+
3550+
[case testForFilter]
3551+
def f(x: int) -> bool:
3552+
return bool(x % 2)
3553+
def g(a: list[int]) -> int:
3554+
s = 0
3555+
for x in filter(f, a):
3556+
s += x
3557+
return s
3558+
[out]
3559+
def f(x):
3560+
x, r0 :: int
3561+
r1 :: bit
3562+
L0:
3563+
r0 = CPyTagged_Remainder(x, 4)
3564+
r1 = r0 != 0
3565+
return r1
3566+
def g(a):
3567+
a :: list
3568+
s :: int
3569+
r0 :: dict
3570+
r1 :: str
3571+
r2 :: object
3572+
r3, r4 :: native_int
3573+
r5 :: bit
3574+
r6 :: object
3575+
r7, x :: int
3576+
r8 :: object
3577+
r9 :: object[1]
3578+
r10 :: object_ptr
3579+
r11 :: object
3580+
r12 :: i32
3581+
r13 :: bit
3582+
r14 :: bool
3583+
r15 :: int
3584+
r16 :: native_int
3585+
L0:
3586+
s = 0
3587+
r0 = __main__.globals :: static
3588+
r1 = 'f'
3589+
r2 = CPyDict_GetItem(r0, r1)
3590+
r3 = 0
3591+
L1:
3592+
r4 = var_object_size a
3593+
r5 = r3 < r4 :: signed
3594+
if r5 goto L3 else goto L5 :: bool
3595+
L2:
3596+
r6 = list_get_item_unsafe a, r3
3597+
r7 = unbox(int, r6)
3598+
x = r7
3599+
r8 = box(int, x)
3600+
r9 = [r8]
3601+
r10 = load_address r9
3602+
r11 = PyObject_Vectorcall(r2, r10, 1, 0)
3603+
keep_alive r8
3604+
r12 = PyObject_IsTrue(r11)
3605+
r13 = r12 >= 0 :: signed
3606+
r14 = truncate r12: i32 to builtins.bool
3607+
if r14 goto L3 else goto L5 :: bool
3608+
L3:
3609+
r15 = CPyTagged_Add(s, x)
3610+
s = r15
3611+
L4:
3612+
r16 = r3 + 1
3613+
r3 = r16
3614+
goto L1
3615+
L5:
3616+
L6:
3617+
return s
3618+

mypyc/test-data/run-loops.test

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,3 +571,33 @@ print([x for x in native.Vector2(4, -5.2)])
571571
[out]
572572
Vector2(x=-2, y=3.1)
573573
\[4, -5.2]
574+
575+
[case testRunForFilter]
576+
def f(a: list[int]) -> int:
577+
s = 0
578+
for x in filter(lambda x: x % 2 == 0, a):
579+
s += x
580+
return s
581+
582+
print(f([1, 2, 3, 4, 5, 6]))
583+
print(f([1, 3, 5]))
584+
print(f([]))
585+
586+
[out]
587+
12
588+
0
589+
0
590+
591+
[case testRunForFilterEdgeCases]
592+
def f(a: list[int]) -> int:
593+
s = 0
594+
for x in filter(lambda x: x > 10, a):
595+
s += x
596+
return s
597+
598+
print(f([5, 15, 25]))
599+
print(f([]))
600+
601+
[out]
602+
40
603+
0

0 commit comments

Comments
 (0)