class ForFilter(ForGenerator):
    """Generate optimized IR for a for loop over filter(f, iterable).

    The loop is lowered to a loop over ``iterable`` itself, using whichever
    generator ``make_for_loop_generator`` picks for it. At the top of every
    iteration the predicate is applied to the freshly assigned loop target;
    items that test false branch to the enclosing loop's continue target and
    the user-written body is skipped for them.
    """

    def need_cleanup(self) -> bool:
        # The wrapped generator might need cleanup. We might generate a
        # redundant cleanup block, but that's okay.
        return True

    def init(self, index: Lvalue, func: Expression, iterable: Expression) -> None:
        """Record the predicate and build the generator for the wrapped iterable.

        Args:
            index: The loop target of the for statement.
            func: First argument of filter(). The literal ``None`` means
                "keep items that are truthy".
            iterable: Second argument of filter(); this is what is iterated.
        """
        self.filter_func_def = func
        self.iterable = iterable
        self.index = index

        filters_by_truthiness = (
            isinstance(func, NameExpr)
            and isinstance(func.node, Var)
            and func.node.fullname == "builtins.None"
        )
        if filters_by_truthiness:
            self.filter_func_val = None
        else:
            # Evaluate the predicate expression exactly once, and before the
            # generator for the iterable is created below.
            self.filter_func_val = self.builder.accept(func)

        # A bare name is cheap and side-effect free to reference again, so the
        # per-item call is built from the name itself; that keeps the call
        # builder free to emit a direct native call or a primitive op. Any
        # other callee expression (a call, a lambda, an attribute of an
        # arbitrary expression, ...) must not be evaluated again per item, so
        # it is invoked through the value computed above.
        # NOTE(review): if the name is rebound inside the loop body, the
        # per-item reference observes the new binding, unlike builtins.filter
        # -- confirm this is acceptable.
        self.call_by_name = not filters_by_truthiness and isinstance(func, NameExpr)

        self.gen = make_for_loop_generator(
            self.builder,
            self.index,
            self.iterable,
            self.body_block,
            self.loop_exit,
            self.line,
            is_async=False,
            nested=True,
        )

    def gen_condition(self) -> None:
        self.gen.gen_condition()

    def begin_body(self) -> None:
        """Assign the next item, then skip the body unless the item passes."""
        # 1. Let the wrapped generator assign the next item to the loop target.
        self.gen.begin_body()

        builder = self.builder
        line = self.line

        # 2. Produce the value whose truthiness decides whether the item is kept.
        if self.filter_func_val is None:
            # filter(None, ...): the item itself is the condition.
            keep = builder.read(builder.get_assignment_target(self.index), line)
        elif self.call_by_name:
            # Imported here to prevent a circular import.
            from mypyc.irbuild.expression import transform_call_expr

            call = CallExpr(self.filter_func_def, [self.index], [ARG_POS], [None])
            keep = transform_call_expr(builder, call)
        else:
            # Call the predicate object evaluated once in init(); the argument
            # is the current value of the loop target.
            keep = builder.py_call(self.filter_func_val, [builder.accept(self.index)], line)

        # 3. Only fall through into the user's loop body when the value is
        #    truthy; otherwise continue with the next item.
        skip_block, kept_block = BasicBlock(), BasicBlock()
        builder.add_bool_branch(keep, kept_block, skip_block)
        builder.activate_block(skip_block)
        builder.nonlocal_control[-1].gen_continue(builder, line)
        builder.goto_and_activate(kept_block)
        # From here on the rest of the loop body (user code) is emitted.

    def gen_step(self) -> None:
        self.gen.gen_step()

    def gen_cleanup(self) -> None:
        self.gen.gen_cleanup()
a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -411,3 +411,11 @@ class classmethod: pass class staticmethod: pass NotImplemented: Any = ... + +class filter(Generic[_T]): + @overload + def __new__(cls, function: None, iterable: Iterable[_T | None], /) -> Self: ... + @overload + def __new__(cls, function: Callable[[_T], Any], iterable: Iterable[_T], /) -> Self: ... + def __iter__(self) -> Self: ... + def __next__(self) -> _T: ... diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index 612f3266fd79..998cd9dd7786 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -3552,6 +3552,278 @@ L0: r3 = box(None, 1) return r3 +[case testForFilterBool] +def f(x: int) -> bool: + return bool(x % 2) +def g(a: list[int]) -> int: + s = 0 + for x in filter(f, a): + s += x + return s +[out] +def f(x): + x, r0 :: int + r1 :: bit +L0: + r0 = CPyTagged_Remainder(x, 4) + r1 = r0 != 0 + return r1 +def g(a): + a :: list + s :: int + r0 :: dict + r1 :: str + r2 :: object + r3, r4 :: native_int + r5 :: bit + r6 :: object + r7, x :: int + r8 :: bool + r9 :: int + r10 :: native_int +L0: + s = 0 + r0 = __main__.globals :: static + r1 = 'f' + r2 = CPyDict_GetItem(r0, r1) + r3 = 0 +L1: + r4 = var_object_size a + r5 = r3 < r4 :: signed + if r5 goto L2 else goto L6 :: bool +L2: + r6 = list_get_item_unsafe a, r3 + r7 = unbox(int, r6) + x = r7 + r8 = f(x) + if r8 goto L4 else goto L3 :: bool +L3: + goto L5 +L4: + r9 = CPyTagged_Add(s, x) + s = r9 +L5: + r10 = r3 + 1 + r3 = r10 + goto L1 +L6: +L7: + return s + +[case testForFilterInt] +def f(x: int) -> int: + return x % 2 +def g(a: list[int]) -> int: + s = 0 + for x in filter(f, a): + s += x + return s +[out] +def f(x): + x, r0 :: int +L0: + r0 = CPyTagged_Remainder(x, 4) + return r0 +def g(a): + a :: list + s :: int + r0 :: dict + r1 :: str + r2 :: object + r3, r4 :: native_int + r5 :: bit + r6 :: object + r7, x, r8 :: int + r9 :: bit + r10 :: int + r11 
:: native_int +L0: + s = 0 + r0 = __main__.globals :: static + r1 = 'f' + r2 = CPyDict_GetItem(r0, r1) + r3 = 0 +L1: + r4 = var_object_size a + r5 = r3 < r4 :: signed + if r5 goto L2 else goto L6 :: bool +L2: + r6 = list_get_item_unsafe a, r3 + r7 = unbox(int, r6) + x = r7 + r8 = f(x) + r9 = r8 != 0 + if r9 goto L4 else goto L3 :: bool +L3: + goto L5 +L4: + r10 = CPyTagged_Add(s, x) + s = r10 +L5: + r11 = r3 + 1 + r3 = r11 + goto L1 +L6: +L7: + return s + +[case testForFilterStr] +def f(x: int) -> str: + return str(x % 2) +def g(a: list[int]) -> int: + s = 0 + for x in filter(f, a): + s += x + return s +[out] +def f(x): + x, r0 :: int + r1 :: str +L0: + r0 = CPyTagged_Remainder(x, 4) + r1 = CPyTagged_Str(r0) + return r1 +def g(a): + a :: list + s :: int + r0 :: dict + r1 :: str + r2 :: object + r3, r4 :: native_int + r5 :: bit + r6 :: object + r7, x :: int + r8 :: str + r9 :: bit + r10 :: int + r11 :: native_int +L0: + s = 0 + r0 = __main__.globals :: static + r1 = 'f' + r2 = CPyDict_GetItem(r0, r1) + r3 = 0 +L1: + r4 = var_object_size a + r5 = r3 < r4 :: signed + if r5 goto L2 else goto L6 :: bool +L2: + r6 = list_get_item_unsafe a, r3 + r7 = unbox(int, r6) + x = r7 + r8 = f(x) + r9 = CPyStr_IsTrue(r8) + if r9 goto L4 else goto L3 :: bool +L3: + goto L5 +L4: + r10 = CPyTagged_Add(s, x) + s = r10 +L5: + r11 = r3 + 1 + r3 = r11 + goto L1 +L6: +L7: + return s + +[case testForFilterPrimitiveOp] +def f(a: list[list[int]]) -> int: + s = 0 + for x in filter(len, a): + s += 1 + return s +[out] +def f(a): + a :: list + s :: int + r0 :: object + r1 :: str + r2 :: object + r3, r4 :: native_int + r5 :: bit + r6 :: object + r7, x :: list + r8 :: native_int + r9 :: short_int + r10 :: bit + r11 :: int + r12 :: native_int +L0: + s = 0 + r0 = builtins :: module + r1 = 'len' + r2 = CPyObject_GetAttr(r0, r1) + r3 = 0 +L1: + r4 = var_object_size a + r5 = r3 < r4 :: signed + if r5 goto L2 else goto L6 :: bool +L2: + r6 = list_get_item_unsafe a, r3 + r7 = cast(list, r6) + x = r7 + r8 = 
var_object_size x + r9 = r8 << 1 + r10 = r9 != 0 + if r10 goto L4 else goto L3 :: bool +L3: + goto L5 +L4: + r11 = CPyTagged_Add(s, 2) + s = r11 +L5: + r12 = r3 + 1 + r3 = r12 + goto L1 +L6: +L7: + return s + +[case testForFilterNone] +def f(a: list[int]) -> int: + c = 0 + for x in filter(None, a): + c += 1 + return 0 + +[out] +def f(a): + a :: list + c :: int + r0, r1 :: native_int + r2 :: bit + r3 :: object + r4, x :: int + r5 :: bit + r6 :: int + r7 :: native_int +L0: + c = 0 + r0 = 0 +L1: + r1 = var_object_size a + r2 = r0 < r1 :: signed + if r2 goto L2 else goto L6 :: bool +L2: + r3 = list_get_item_unsafe a, r0 + r4 = unbox(int, r3) + x = r4 + r5 = x != 0 + if r5 goto L4 else goto L3 :: bool +L3: + goto L5 +L4: + r6 = CPyTagged_Add(c, 2) + c = r6 +L5: + r7 = r0 + 1 + r0 = r7 + goto L1 +L6: +L7: + return 0 + [case testStarArgFastPathTuple] from typing import Any, Callable def deco(fn: Callable[..., Any]) -> Callable[..., Any]: diff --git a/mypyc/test-data/run-loops.test b/mypyc/test-data/run-loops.test index 3cbb07297e6e..f701e1f18e98 100644 --- a/mypyc/test-data/run-loops.test +++ b/mypyc/test-data/run-loops.test @@ -571,3 +571,46 @@ print([x for x in native.Vector2(4, -5.2)]) [out] Vector2(x=-2, y=3.1) \[4, -5.2] + +[case testRunForFilter] +def f(a: list[int]) -> int: + s = 0 + for x in filter(lambda x: x % 2 == 0, a): + s += x + return s +def g(a: list[int]) -> int: + s = 0 + for x in filter(lambda x: x > 10, a): + s += x + return s +def with_none(a: list[int]) -> int: + c = 0 + for x in filter(None, a): + c += 1 + return c +def native_func(x: int) -> int: + return x % 2 +def with_native_func(a: list[int]) -> int: + c = 0 + for x in filter(native_func, a): + c += 1 + return c +def with_primitive(a: list[list[int]]) -> int: + c = 0 + for x in filter(len, a): + c += 1 + return c + +def test_run_for_filter() -> None: + assert f([1, 2, 3, 4, 5, 6]) == 12 + assert f([1, 3, 5]) == 0 + assert f([]) == 0 +def test_run_for_filter_with_none() -> None: + assert 
with_none([0, 1, 2, 3, 4, 5, 6]) == 6 +def test_run_for_filter_with_native_func() -> None: + assert with_native_func([0, 1, 2, 3, 4, 5, 6]) == 3 +def test_run_for_filter_with_primitive() -> None: + assert with_primitive([[], [0, 1], [], [], [2, 3, 4], [5, 6]]) == 3 +def test_run_for_filter_edge_cases() -> None: + assert g([5, 15, 25]) == 40 + assert g([]) == 0