-
-
Notifications
You must be signed in to change notification settings - Fork 3k
[mypyc] feat: ForFilter generator helper for builtins.filter
#19643
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 36 commits
b3f3eab
67818c6
74b7a6e
eeb09ab
ddc13b8
fc12cea
5ce8148
54ad04e
eae9209
9941d54
71b27ef
5237f0b
d68b833
c9680dc
5bf4b22
c39bb4a
8e43b2e
9dceb9a
7c8053f
8aff832
0d2c019
cec1a5d
5170a10
ba5a978
572793c
55ed2d6
dbbbb57
0bc1d26
7d56fa9
4bf480d
11b04c3
fa54df2
d2edf7b
715ce46
a3b65b3
26db7f5
1197e4e
75f61e3
58798da
f18eca3
26ad0d5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -499,6 +499,16 @@ def make_for_loop_generator( | |
for_list = ForSequence(builder, index, body_block, loop_exit, line, nested) | ||
for_list.init(expr_reg, target_type, reverse=True) | ||
return for_list | ||
|
||
elif ( | ||
expr.callee.fullname == "builtins.filter" | ||
and len(expr.args) == 2 | ||
and all(k == ARG_POS for k in expr.arg_kinds) | ||
): | ||
for_filter = ForFilter(builder, index, body_block, loop_exit, line, nested) | ||
for_filter.init(index, expr.args[0], expr.args[1]) | ||
return for_filter | ||
|
||
if isinstance(expr, CallExpr) and isinstance(expr.callee, MemberExpr) and not expr.args: | ||
# Special cases for dictionary iterator methods, like dict.items(). | ||
rtype = builder.node_type(expr.callee.expr) | ||
|
@@ -1180,6 +1190,76 @@ def gen_cleanup(self) -> None: | |
gen.gen_cleanup() | ||
|
||
|
||
class ForFilter(ForGenerator):
    """Generate optimized IR for a for loop over filter(f, iterable).

    The loop over ``iterable`` is delegated to a nested generator created by
    make_for_loop_generator(). At the start of every iteration the predicate
    is applied to the loop target; when the result is falsy, a ``continue``
    is emitted so the user-written loop body is skipped for that item.
    """

    def need_cleanup(self) -> bool:
        # The wrapped for loops might need cleanup. We might generate a
        # redundant cleanup block, but that's okay.
        return True

    def init(self, index: Lvalue, func: Expression, iterable: Expression) -> None:
        """Record the filter arguments and build the nested loop generator.

        Args:
            index: The loop target that each item is assigned to.
            func: The first argument of filter(): a predicate expression, or
                a reference to ``None``.
            iterable: The second argument of filter(), i.e. what is iterated.
        """
        self.filter_func_def = func
        if (
            isinstance(func, NameExpr)
            and isinstance(func.node, Var)
            and func.node.fullname == "builtins.None"
        ):
            # filter(None, iterable): items are tested for truthiness directly.
            self.filter_func_val = None
        else:
            # The value is only used by begin_body() to tell the two cases
            # apart; the call itself is rebuilt from filter_func_def there.
            self.filter_func_val = self.builder.accept(func)
        self.iterable = iterable
        self.index = index

        # The nested generator does the actual iteration; this class only
        # adds the per-item predicate check on top of it.
        self.gen = make_for_loop_generator(
            self.builder,
            self.index,
            self.iterable,
            self.body_block,
            self.loop_exit,
            self.line,
            is_async=False,
            nested=True,
        )

    def gen_condition(self) -> None:
        self.gen.gen_condition()

    def begin_body(self) -> None:
        """Assign the next item, then skip the body unless it passes the filter."""
        # 1. Assign the next item to the loop variable.
        self.gen.begin_body()

        # 2. Compute the value whose truthiness decides whether to keep the item.
        builder = self.builder
        line = self.line

        if self.filter_func_val is None:
            # Only this branch needs the item value itself, so the target is
            # read here rather than unconditionally.
            # NOTE(review): builder.read() presumably only supports simple
            # targets; confirm behavior for tuple targets with filter(None, ...).
            result = builder.read(builder.get_assignment_target(self.index), line)
        else:
            # Build a synthetic call ``func(index)`` and let the regular
            # expression machinery generate it.
            # NOTE(review): this appears to re-evaluate the callee expression
            # on every iteration -- confirm that is acceptable for callees
            # with side effects.
            fake_call_expr = CallExpr(self.filter_func_def, [self.index], [ARG_POS], [None])
            result = builder.accept(fake_call_expr)

        # 3. Filter: only enter the body if the result is truthy; otherwise
        #    emit a ``continue`` for the current loop.
        cont_block, rest_block = BasicBlock(), BasicBlock()
        builder.add_bool_branch(result, rest_block, cont_block)
        builder.activate_block(cont_block)
        builder.nonlocal_control[-1].gen_continue(builder, line)
        builder.goto_and_activate(rest_block)
        # At this point, the rest of the loop body (user code) will be emitted.

    def gen_step(self) -> None:
        self.gen.gen_step()

    def gen_cleanup(self) -> None:
        self.gen.gen_cleanup()
|
||
|
||
def get_expr_length(expr: Expression) -> int | None: | ||
if isinstance(expr, (StrExpr, BytesExpr)): | ||
return len(expr.value) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3552,6 +3552,278 @@ L0: | |
r3 = box(None, 1) | ||
return r3 | ||
|
||
[case testForFilterBool] | ||
def f(x: int) -> bool: | ||
return bool(x % 2) | ||
def g(a: list[int]) -> int: | ||
s = 0 | ||
for x in filter(f, a): | ||
s += x | ||
return s | ||
[out] | ||
def f(x): | ||
x, r0 :: int | ||
r1 :: bit | ||
L0: | ||
r0 = CPyTagged_Remainder(x, 4) | ||
r1 = r0 != 0 | ||
return r1 | ||
def g(a): | ||
a :: list | ||
s :: int | ||
r0 :: dict | ||
r1 :: str | ||
r2 :: object | ||
r3, r4 :: native_int | ||
r5 :: bit | ||
r6 :: object | ||
r7, x :: int | ||
r8 :: bool | ||
r9 :: int | ||
r10 :: native_int | ||
L0: | ||
s = 0 | ||
r0 = __main__.globals :: static | ||
r1 = 'f' | ||
r2 = CPyDict_GetItem(r0, r1) | ||
r3 = 0 | ||
L1: | ||
r4 = var_object_size a | ||
r5 = r3 < r4 :: signed | ||
if r5 goto L2 else goto L6 :: bool | ||
L2: | ||
r6 = list_get_item_unsafe a, r3 | ||
r7 = unbox(int, r6) | ||
x = r7 | ||
r8 = f(x) | ||
if r8 goto L4 else goto L3 :: bool | ||
L3: | ||
goto L5 | ||
L4: | ||
r9 = CPyTagged_Add(s, x) | ||
s = r9 | ||
L5: | ||
r10 = r3 + 1 | ||
r3 = r10 | ||
goto L1 | ||
L6: | ||
L7: | ||
return s | ||
|
||
[case testForFilterInt] | ||
def f(x: int) -> int: | ||
return x % 2 | ||
def g(a: list[int]) -> int: | ||
s = 0 | ||
for x in filter(f, a): | ||
s += x | ||
return s | ||
[out] | ||
def f(x): | ||
x, r0 :: int | ||
L0: | ||
r0 = CPyTagged_Remainder(x, 4) | ||
return r0 | ||
def g(a): | ||
a :: list | ||
s :: int | ||
r0 :: dict | ||
r1 :: str | ||
r2 :: object | ||
r3, r4 :: native_int | ||
r5 :: bit | ||
r6 :: object | ||
r7, x, r8 :: int | ||
r9 :: bit | ||
r10 :: int | ||
r11 :: native_int | ||
L0: | ||
s = 0 | ||
r0 = __main__.globals :: static | ||
r1 = 'f' | ||
r2 = CPyDict_GetItem(r0, r1) | ||
r3 = 0 | ||
L1: | ||
r4 = var_object_size a | ||
r5 = r3 < r4 :: signed | ||
if r5 goto L2 else goto L6 :: bool | ||
L2: | ||
r6 = list_get_item_unsafe a, r3 | ||
r7 = unbox(int, r6) | ||
x = r7 | ||
r8 = f(x) | ||
r9 = r8 != 0 | ||
if r9 goto L4 else goto L3 :: bool | ||
L3: | ||
goto L5 | ||
L4: | ||
r10 = CPyTagged_Add(s, x) | ||
s = r10 | ||
L5: | ||
r11 = r3 + 1 | ||
r3 = r11 | ||
goto L1 | ||
L6: | ||
L7: | ||
return s | ||
|
||
[case testForFilterStr] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This seems specialized enough to move to a run test (see my comments in other PRs about our conventions related to irbuild tests). |
||
def f(x: int) -> str: | ||
return str(x % 2) | ||
def g(a: list[int]) -> int: | ||
s = 0 | ||
for x in filter(f, a): | ||
s += x | ||
return s | ||
[out] | ||
def f(x): | ||
x, r0 :: int | ||
r1 :: str | ||
L0: | ||
r0 = CPyTagged_Remainder(x, 4) | ||
r1 = CPyTagged_Str(r0) | ||
return r1 | ||
def g(a): | ||
a :: list | ||
s :: int | ||
r0 :: dict | ||
r1 :: str | ||
r2 :: object | ||
r3, r4 :: native_int | ||
r5 :: bit | ||
r6 :: object | ||
r7, x :: int | ||
r8 :: str | ||
r9 :: bit | ||
r10 :: int | ||
r11 :: native_int | ||
L0: | ||
s = 0 | ||
r0 = __main__.globals :: static | ||
r1 = 'f' | ||
r2 = CPyDict_GetItem(r0, r1) | ||
r3 = 0 | ||
L1: | ||
r4 = var_object_size a | ||
r5 = r3 < r4 :: signed | ||
if r5 goto L2 else goto L6 :: bool | ||
L2: | ||
r6 = list_get_item_unsafe a, r3 | ||
r7 = unbox(int, r6) | ||
x = r7 | ||
r8 = f(x) | ||
r9 = CPyStr_IsTrue(r8) | ||
if r9 goto L4 else goto L3 :: bool | ||
L3: | ||
goto L5 | ||
L4: | ||
r10 = CPyTagged_Add(s, x) | ||
s = r10 | ||
L5: | ||
r11 = r3 + 1 | ||
r3 = r11 | ||
goto L1 | ||
L6: | ||
L7: | ||
return s | ||
|
||
[case testForFilterPrimitiveOp] | ||
def f(a: list[list[int]]) -> int: | ||
s = 0 | ||
for x in filter(len, a): | ||
s += 1 | ||
return s | ||
[out] | ||
def f(a): | ||
a :: list | ||
s :: int | ||
r0 :: object | ||
r1 :: str | ||
r2 :: object | ||
r3, r4 :: native_int | ||
r5 :: bit | ||
r6 :: object | ||
r7, x :: list | ||
r8 :: native_int | ||
r9 :: short_int | ||
r10 :: bit | ||
r11 :: int | ||
r12 :: native_int | ||
L0: | ||
s = 0 | ||
r0 = builtins :: module | ||
r1 = 'len' | ||
r2 = CPyObject_GetAttr(r0, r1) | ||
r3 = 0 | ||
L1: | ||
r4 = var_object_size a | ||
r5 = r3 < r4 :: signed | ||
if r5 goto L2 else goto L6 :: bool | ||
L2: | ||
r6 = list_get_item_unsafe a, r3 | ||
r7 = cast(list, r6) | ||
x = r7 | ||
r8 = var_object_size x | ||
r9 = r8 << 1 | ||
r10 = r9 != 0 | ||
if r10 goto L4 else goto L3 :: bool | ||
L3: | ||
goto L5 | ||
L4: | ||
r11 = CPyTagged_Add(s, 2) | ||
s = r11 | ||
L5: | ||
r12 = r3 + 1 | ||
r3 = r12 | ||
goto L1 | ||
L6: | ||
L7: | ||
return s | ||
|
||
[case testForFilterNone] | ||
def f(a: list[int]) -> int: | ||
c = 0 | ||
for x in filter(None, a): | ||
c += 1 | ||
return 0 | ||
|
||
[out] | ||
def f(a): | ||
a :: list | ||
c :: int | ||
r0, r1 :: native_int | ||
r2 :: bit | ||
r3 :: object | ||
r4, x :: int | ||
r5 :: bit | ||
r6 :: int | ||
r7 :: native_int | ||
L0: | ||
c = 0 | ||
r0 = 0 | ||
L1: | ||
r1 = var_object_size a | ||
r2 = r0 < r1 :: signed | ||
if r2 goto L2 else goto L6 :: bool | ||
L2: | ||
r3 = list_get_item_unsafe a, r0 | ||
r4 = unbox(int, r3) | ||
x = r4 | ||
r5 = x != 0 | ||
if r5 goto L4 else goto L3 :: bool | ||
L3: | ||
goto L5 | ||
L4: | ||
r6 = CPyTagged_Add(c, 2) | ||
c = r6 | ||
L5: | ||
r7 = r0 + 1 | ||
r0 = r7 | ||
goto L1 | ||
L6: | ||
L7: | ||
return 0 | ||
|
||
[case testStarArgFastPathTuple] | ||
from typing import Any, Callable | ||
def deco(fn: Callable[..., Any]) -> Callable[..., Any]: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have you considered supporting
list(filter(...))
as well (in a follow-up PR)? This seems quite common.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, I actually have that drafted already, but it won't be a special case for
list(filter(...))
Instead, it will be a special case for [list|tuple|set](some_builtin_we_have_a_helper_for_in_for_helpers(...))
which will account for any builtin we have ForGenerator helpers for.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW, this was part of the intent behind the list-built-from-range tests.
I wasn't actually testing that we can build a list from a range; I was preparing IR to reflect how this helper would change the C implementation. It will work for map, filter, range, zip, enumerate, and future ops with special-case gen helpers.