diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py
index 20440d4a26f4..6f95026ff4ed 100644
--- a/mypyc/irbuild/for_helpers.py
+++ b/mypyc/irbuild/for_helpers.py
@@ -11,6 +11,7 @@
 
 from mypy.nodes import (
     ARG_POS,
+    LDEF,
     BytesExpr,
     CallExpr,
     DictionaryComprehension,
@@ -28,7 +29,7 @@
     TypeAlias,
     Var,
 )
-from mypy.types import LiteralType, TupleType, get_proper_type, get_proper_types
+from mypy.types import LiteralType, TupleType, Type, get_proper_type, get_proper_types
 from mypyc.ir.ops import (
     ERR_NEVER,
     BasicBlock,
@@ -1241,3 +1242,42 @@ def get_expr_length_value(
         return builder.builder.builtin_len(expr_reg, line, use_pyssize_t=use_pyssize_t)
     # The expression result is known at compile time, so we can use a constant.
     return Integer(length, c_pyssize_t_rprimitive if use_pyssize_t else short_int_rprimitive)
+
+
+def expr_has_specialized_for_helper(builder: IRBuilder, expr: Expression) -> bool:
+    """Return True if expr is recognized here as having a specialized for loop helper.
+
+    That is the case when expr has a sequence rprimitive type, or when it is a call
+    to range(), enumerate() or zip(), or to dict.keys(), dict.values() or dict.items().
+    """
+    if is_sequence_rprimitive(builder.node_type(expr)):
+        return True
+    if not isinstance(expr, CallExpr):
+        return False
+    callee = expr.callee
+    # MemberExpr is a RefExpr subclass, so this one check admits both plain-name
+    # callees (range) and attribute callees (dict methods).
+    if not isinstance(callee, RefExpr):
+        return False
+    # NOTE(review): confirm fullname is populated for bound-method callees such as
+    # `d.keys()`; if it is not, the dict entries below can never match.
+    return callee.fullname in {
+        "builtins.range",
+        "builtins.enumerate",
+        "builtins.zip",
+        "builtins.dict.keys",
+        "builtins.dict.values",
+        "builtins.dict.items",
+    }
+
+
+def create_synthetic_nameexpr(index_name: str, index_type: Type) -> NameExpr:
+    """Build a NameExpr to use as the lvalue in one of the for loop helpers.
+
+    The returned expression is marked as a local definition (LDEF) and its node is
+    a fresh Var named index_name with type index_type.
+    """
+    index = NameExpr(index_name)
+    index.kind = LDEF
+    index.node = Var(index_name, index_type)
+    return index
diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py
index e810f11bd079..591cf3b0bbfb 100644
--- a/mypyc/irbuild/specialize.py
+++ b/mypyc/irbuild/specialize.py
@@ -80,6 +80,9 @@
 from mypyc.irbuild.constant_fold import constant_fold_expr
 from mypyc.irbuild.for_helpers import (
     comprehension_helper,
+    create_synthetic_nameexpr,
+    
expr_has_specialized_for_helper,
+    for_loop_helper,
     sequence_from_generator_preallocate_helper,
     translate_list_comprehension,
     translate_set_comprehension,
@@ -412,29 +415,74 @@ def translate_safe_generator_call(
 
+def _any_all_over_iterable(
+    builder: IRBuilder, expr: CallExpr, index_name: str, *, stop_on: bool
+) -> Value:
+    """Build a short-circuiting loop that computes any()/all() of expr.args[0].
+
+    The result register starts out as `not stop_on`. At the first item whose truth
+    value equals stop_on (True for any(), False for all()), the result is set to
+    stop_on and the loop is left early. index_name names the synthetic loop variable.
+    """
+    arg = expr.args[0]
+    retval = Register(bool_rprimitive)
+    builder.assign(retval, builder.false() if stop_on else builder.true(), -1)
+    loop_exit = BasicBlock()
+
+    index_type = builder._analyze_iterable_item_type(arg)
+    index = create_synthetic_nameexpr(index_name, index_type)
+    index_reg = builder.add_local_reg(index.node, builder.type_to_rtype(index_type))  # type: ignore [arg-type]
+
+    def body_insts() -> None:
+        true_block = BasicBlock()
+        false_block = BasicBlock()
+        builder.add_bool_branch(builder.read(index_reg), true_block, false_block)
+        # Emit the early-exit block first, then continue the loop in the other block.
+        if stop_on:
+            stop_block, continue_block = true_block, false_block
+        else:
+            stop_block, continue_block = false_block, true_block
+        builder.activate_block(stop_block)
+        builder.assign(retval, builder.true() if stop_on else builder.false(), -1)
+        builder.goto(loop_exit)
+        builder.activate_block(continue_block)
+
+    for_loop_helper(builder, index, arg, body_insts, None, is_async=False, line=expr.line)
+    builder.goto_and_activate(loop_exit)
+    return retval
+
+
 @specialize_function("builtins.any")
 def translate_any_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
-    if (
-        len(expr.args) == 1
-        and expr.arg_kinds == [ARG_POS]
-        and isinstance(expr.args[0], GeneratorExpr)
-    ):
-        return any_all_helper(builder, expr.args[0], builder.false, lambda x: x, builder.true)
+    """Specialize any(x) when x is a generator expression or has a specialized for loop.
+
+    Returns None for any other call shape.
+    """
+    if len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]:
+        arg = expr.args[0]
+        if isinstance(arg, GeneratorExpr):
+            return any_all_helper(builder, arg, builder.false, lambda x: x, builder.true)
+        if expr_has_specialized_for_helper(builder, arg):
+            return _any_all_over_iterable(builder, expr, "__mypyc_any_item__", stop_on=True)
     return None
 
 
 @specialize_function("builtins.all")
 def translate_all_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
-    if (
-        len(expr.args) == 1
-        and expr.arg_kinds == [ARG_POS]
-        and isinstance(expr.args[0], GeneratorExpr)
-    ):
-        return any_all_helper(
-            builder,
-            expr.args[0],
-            builder.true,
-            lambda x: builder.unary_op(x, "not", expr.line),
-            builder.false,
-        )
+    """Specialize all(x) when x is a generator expression or has a specialized for loop.
+
+    Returns None for any other call shape.
+    """
+    if len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]:
+        arg = expr.args[0]
+        if isinstance(arg, GeneratorExpr):
+            return any_all_helper(
+                builder,
+                arg,
+                builder.true,
+                lambda x: builder.unary_op(x, "not", expr.line),
+                builder.false,
+            )
+        if expr_has_specialized_for_helper(builder, arg):
+            return _any_all_over_iterable(builder, expr, "__mypyc_all_item__", stop_on=False)
     return None
 
 
diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test
index 612f3266fd79..8bb94eb84ca6 100644
--- a/mypyc/test-data/irbuild-basic.test
+++ b/mypyc/test-data/irbuild-basic.test
@@ -2812,6 +2812,12 @@ def call_any(l: Iterable[int]) -> bool:
 def call_all(l: Iterable[int]) -> bool:
     return all(i == 0 for i in l)
 
+def call_any_helper(l: list[Iterable[int]]) -> bool:
+    return any([str(i) for i in l])
+
+def call_all_helper(l: list[Iterable[int]]) -> bool:
+    return all([str(i) for i in l])
+
 [out]
 def call_any(l):
     l :: object
@@ -2870,6 +2876,118 @@ L6:
 L7:
 L8:
     return r0
+def call_any_helper(l):
+    l :: list
+    r0 :: bool
+    r1 :: native_int
+    r2 :: list
+    r3, r4 :: native_int
+    r5 :: bit
+    r6, i :: object
+    r7 :: str
+    r8, r9, r10 :: native_int
+    r11 :: bit
+    r12 :: object
+    r13, __mypyc_any_item__ :: str
+ r14 :: bit + r15 :: native_int +L0: + r0 = 0 + r1 = var_object_size l + r2 = PyList_New(r1) + r3 = 0 +L1: + r4 = var_object_size l + r5 = r3 < r4 :: signed + if r5 goto L2 else goto L4 :: bool +L2: + r6 = list_get_item_unsafe l, r3 + i = r6 + r7 = PyObject_Str(i) + CPyList_SetItemUnsafe(r2, r3, r7) +L3: + r8 = r3 + 1 + r3 = r8 + goto L1 +L4: + r9 = 0 +L5: + r10 = var_object_size r2 + r11 = r9 < r10 :: signed + if r11 goto L6 else goto L10 :: bool +L6: + r12 = list_get_item_unsafe r2, r9 + r13 = cast(str, r12) + __mypyc_any_item__ = r13 + r14 = CPyStr_IsTrue(__mypyc_any_item__) + if r14 goto L7 else goto L8 :: bool +L7: + r0 = 1 + goto L11 +L8: +L9: + r15 = r9 + 1 + r9 = r15 + goto L5 +L10: +L11: + return r0 +def call_all_helper(l): + l :: list + r0 :: bool + r1 :: native_int + r2 :: list + r3, r4 :: native_int + r5 :: bit + r6, i :: object + r7 :: str + r8, r9, r10 :: native_int + r11 :: bit + r12 :: object + r13, __mypyc_all_item__ :: str + r14 :: bit + r15 :: native_int +L0: + r0 = 1 + r1 = var_object_size l + r2 = PyList_New(r1) + r3 = 0 +L1: + r4 = var_object_size l + r5 = r3 < r4 :: signed + if r5 goto L2 else goto L4 :: bool +L2: + r6 = list_get_item_unsafe l, r3 + i = r6 + r7 = PyObject_Str(i) + CPyList_SetItemUnsafe(r2, r3, r7) +L3: + r8 = r3 + 1 + r3 = r8 + goto L1 +L4: + r9 = 0 +L5: + r10 = var_object_size r2 + r11 = r9 < r10 :: signed + if r11 goto L6 else goto L10 :: bool +L6: + r12 = list_get_item_unsafe r2, r9 + r13 = cast(str, r12) + __mypyc_all_item__ = r13 + r14 = CPyStr_IsTrue(__mypyc_all_item__) + if r14 goto L8 else goto L7 :: bool +L7: + r0 = 0 + goto L11 +L8: +L9: + r15 = r9 + 1 + r9 = r15 + goto L5 +L10: +L11: + return r0 [case testSum] from typing import Callable, Iterable diff --git a/mypyc/test-data/run-misc.test b/mypyc/test-data/run-misc.test index 1074906357ee..c12b6f1e0a64 100644 --- a/mypyc/test-data/run-misc.test +++ b/mypyc/test-data/run-misc.test @@ -794,8 +794,18 @@ def call_all(l: Iterable[int], val: int = 0) -> int: res = 
all(i == val for i in l) return 0 if res else 1 +def call_any_with_for_helper(l: Iterable[int], val: int = 0) -> int: + # this listcomp isn't a reasonable real-world use but proves the any for loop specializer is good + res = any([i == val for i in l]) + return 0 if res else 1 + +def call_all_with_for_helper(l: Iterable[int], val: int = 0) -> int: + # this listcomp isn't a reasonable real-world use but proves the all for loop specializer is good + res = all([i == val for i in l]) + return 0 if res else 1 + [file driver.py] -from native import call_any, call_all, call_any_nested +from native import call_any, call_all, call_any_nested, call_any_with_for_helper, call_all_with_for_helper zeros = [0, 0, 0] ones = [1, 1, 1] @@ -807,24 +817,42 @@ mixed_101 = [1, 0, 1] mixed_110 = [1, 1, 0] assert call_any([]) == 1 +assert call_any_with_for_helper([]) == 1 assert call_any(zeros) == 0 +assert call_any_with_for_helper(zeros) == 0 assert call_any(ones) == 1 +assert call_any_with_for_helper(ones) == 1 assert call_any(mixed_001) == 0 +assert call_any_with_for_helper(mixed_001) == 0 assert call_any(mixed_010) == 0 +assert call_any_with_for_helper(mixed_010) == 0 assert call_any(mixed_100) == 0 +assert call_any_with_for_helper(mixed_100) == 0 assert call_any(mixed_011) == 0 +assert call_any_with_for_helper(mixed_011) == 0 assert call_any(mixed_101) == 0 +assert call_any_with_for_helper(mixed_101) == 0 assert call_any(mixed_110) == 0 +assert call_any_with_for_helper(mixed_110) == 0 assert call_all([]) == 0 +assert call_all_with_for_helper([]) == 0 assert call_all(zeros) == 0 +assert call_all_with_for_helper(zeros) == 0 assert call_all(ones) == 1 +assert call_all_with_for_helper(ones) == 1 assert call_all(mixed_001) == 1 +assert call_all_with_for_helper(mixed_001) == 1 assert call_all(mixed_010) == 1 +assert call_all_with_for_helper(mixed_010) == 1 assert call_all(mixed_100) == 1 +assert call_all_with_for_helper(mixed_100) == 1 assert call_all(mixed_011) == 1 +assert 
call_all_with_for_helper(mixed_011) == 1 assert call_all(mixed_101) == 1 +assert call_all_with_for_helper(mixed_101) == 1 assert call_all(mixed_110) == 1 +assert call_all_with_for_helper(mixed_110) == 1 assert call_any_nested([[1, 1, 1], [1, 1], []]) == 1 assert call_any_nested([[1, 1, 1], [0, 1], []]) == 0