diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index 20440d4a26f4..64f2b01a0d43 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -11,6 +11,7 @@ from mypy.nodes import ( ARG_POS, + LDEF, BytesExpr, CallExpr, DictionaryComprehension, @@ -28,7 +29,7 @@ TypeAlias, Var, ) -from mypy.types import LiteralType, TupleType, get_proper_type, get_proper_types +from mypy.types import LiteralType, TupleType, Type, get_proper_type, get_proper_types from mypyc.ir.ops import ( ERR_NEVER, BasicBlock, @@ -1241,3 +1242,28 @@ def get_expr_length_value( return builder.builder.builtin_len(expr_reg, line, use_pyssize_t=use_pyssize_t) # The expression result is known at compile time, so we can use a constant. return Integer(length, c_pyssize_t_rprimitive if use_pyssize_t else short_int_rprimitive) + + +def _is_supported_forloop_iter(builder: IRBuilder, expr: Expression) -> bool: + if is_sequence_rprimitive(builder.node_type(expr)): + return True + if not isinstance(expr, CallExpr): + return False + if isinstance(expr.callee, RefExpr): + return expr.callee.fullname in { + "builtins.range", + "builtins.enumerate", + "builtins.zip", + "builtins.reversed", + } + elif isinstance(expr.callee, MemberExpr): + return expr.callee.fullname in {"keys", "values", "items"} + return False + + +def _create_iterable_lexpr(index_name: str, index_type: Type) -> NameExpr: + """This helper spoofs a NameExpr to use as the lvalue in one of the for loop helpers.""" + index = NameExpr(index_name) + index.kind = LDEF + index.node = Var(index_name, index_type) + return index diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index e810f11bd079..48aa5c94a8cc 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -56,6 +56,7 @@ bytes_rprimitive, c_int_rprimitive, dict_rprimitive, + float_rprimitive, int16_rprimitive, int32_rprimitive, int64_rprimitive, @@ -69,6 +70,7 @@ is_int64_rprimitive, is_int_rprimitive, is_list_rprimitive, + is_object_rprimitive, is_uint8_rprimitive, list_rprimitive, object_rprimitive, @@ -79,7 +81,10 @@ from mypyc.irbuild.builder import IRBuilder from mypyc.irbuild.constant_fold import constant_fold_expr from mypyc.irbuild.for_helpers import ( + _create_iterable_lexpr, + _is_supported_forloop_iter, comprehension_helper, + for_loop_helper, sequence_from_generator_preallocate_helper, translate_list_comprehension, translate_set_comprehension, @@ -412,29 +417,70 @@ def translate_safe_generator_call( @specialize_function("builtins.any") def translate_any_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None: - if ( - len(expr.args) == 1 - and expr.arg_kinds == [ARG_POS] - and isinstance(expr.args[0], GeneratorExpr) - ): - return any_all_helper(builder, expr.args[0], builder.false, lambda x: x, builder.true) + if len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]: + arg = expr.args[0] + if isinstance(arg, GeneratorExpr): + return any_all_helper(builder, arg, builder.false, lambda x: x, builder.true) + elif _is_supported_forloop_iter(builder, arg): + retval = Register(bool_rprimitive) + builder.assign(retval, builder.false(), -1) + loop_exit = BasicBlock() + index_name = "__mypyc_any_item__" + + def body_insts() -> None: + true_block = BasicBlock() + false_block = BasicBlock() + builder.add_bool_branch(builder.read(index_reg), true_block, false_block) + builder.activate_block(true_block) + builder.assign(retval, builder.true(), -1) + builder.goto(loop_exit) + builder.activate_block(false_block) + + index_type = builder._analyze_iterable_item_type(arg) + index = _create_iterable_lexpr(index_name, index_type) + index_reg = builder.add_local_reg(index.node, builder.type_to_rtype(index_type)) # type: ignore [arg-type] + + for_loop_helper(builder, index, arg, body_insts, None, is_async=False, line=expr.line) + builder.goto_and_activate(loop_exit) + return retval return None @specialize_function("builtins.all") def translate_all_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None: - if ( - len(expr.args) == 1 - and expr.arg_kinds == [ARG_POS] - and isinstance(expr.args[0], GeneratorExpr) - ): - return any_all_helper( - builder, - expr.args[0], - builder.true, - lambda x: builder.unary_op(x, "not", expr.line), - builder.false, - ) + if len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]: + arg = expr.args[0] + if isinstance(arg, GeneratorExpr): + return any_all_helper( + builder, + arg, + builder.true, + lambda x: builder.unary_op(x, "not", expr.line), + builder.false, + ) + + elif _is_supported_forloop_iter(builder, arg): + retval = Register(bool_rprimitive) + builder.assign(retval, builder.true(), -1) + loop_exit = BasicBlock() + index_name = "__mypyc_all_item__" + + def body_insts() -> None: + true_block = BasicBlock() + false_block = BasicBlock() + builder.add_bool_branch(builder.read(index_reg), true_block, false_block) + builder.activate_block(false_block) + builder.assign(retval, builder.false(), -1) + builder.goto(loop_exit) + builder.activate_block(true_block) + + index_type = builder._analyze_iterable_item_type(arg) + index = _create_iterable_lexpr(index_name, index_type) + index_reg = builder.add_local_reg(index.node, builder.type_to_rtype(index_type)) # type: ignore [arg-type] + + for_loop_helper(builder, index, arg, body_insts, None, is_async=False, line=expr.line) + builder.goto_and_activate(loop_exit) + return retval return None @@ -470,11 +516,11 @@ def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> V # - only one or two arguments given (if not, sum() has been given invalid arguments) # - first argument is a Generator (there is no benefit to optimizing the performance of eg. # sum([1, 2, 3]), so non-Generator Iterables are not handled) - if not ( - len(expr.args) in (1, 2) - and expr.arg_kinds[0] == ARG_POS - and isinstance(expr.args[0], GeneratorExpr) - ): + if not (len(expr.args) in (1, 2) and expr.arg_kinds[0] == ARG_POS): + return None + + arg = expr.args[0] + if not isinstance(arg, GeneratorExpr) and not _is_supported_forloop_iter(builder, arg): return None # handle 'start' argument, if given @@ -486,21 +532,51 @@ def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> V else: start_expr = IntExpr(0) - gen_expr = expr.args[0] - target_type = builder.node_type(expr) - retval = Register(target_type) - builder.assign(retval, builder.coerce(builder.accept(start_expr), target_type, -1), -1) + item_type = builder._analyze_iterable_item_type(arg) + item_rtype = builder.type_to_rtype(item_type) + start_rtype = builder.node_type(start_expr) - def gen_inner_stmts() -> None: - call_expr = builder.accept(gen_expr.left_expr) - builder.assign(retval, builder.binary_op(retval, call_expr, "+", -1), -1) + if item_rtype is start_rtype: + acc_rtype = item_rtype + elif is_float_rprimitive(item_rtype) and is_int_rprimitive(start_rtype): + acc_rtype = float_rprimitive + elif is_bool_rprimitive(item_rtype) and is_int_rprimitive(start_rtype): + acc_rtype = int_rprimitive + elif is_object_rprimitive(item_rtype) and is_int_rprimitive(start_rtype): + acc_rtype = object_rprimitive - loop_params = list( - zip(gen_expr.indices, gen_expr.sequences, gen_expr.condlists, gen_expr.is_async) - ) - comprehension_helper(builder, loop_params, gen_inner_stmts, gen_expr.line) + else: + # escape hatch, maybe figure out a better way to handle this whole block + # seeking ideas in review + return None - return retval + retval = Register(acc_rtype) + builder.assign(retval, builder.coerce(builder.accept(start_expr), acc_rtype, -1), -1) + + if isinstance(arg, GeneratorExpr): + + def gen_inner_stmts() -> None: + call_expr = builder.accept(arg.left_expr) + builder.assign(retval, builder.binary_op(retval, call_expr, "+", -1), -1) + + loop_params = list(zip(arg.indices, arg.sequences, arg.condlists, arg.is_async)) + comprehension_helper(builder, loop_params, gen_inner_stmts, arg.line) + + return retval + + else: + index_name = "__mypyc_sum_item__" + + def body_insts() -> None: + total = builder.binary_op(retval, builder.read(index_reg), "+", expr.line) + builder.assign(retval, total, expr.line) + + index_type = builder._analyze_iterable_item_type(arg) + index = _create_iterable_lexpr(index_name, index_type) + index_reg = builder.add_local_reg(index.node, builder.type_to_rtype(index_type)) # type: ignore [arg-type] + + for_loop_helper(builder, index, arg, body_insts, None, is_async=False, line=expr.line) + return retval @specialize_function("dataclasses.field") diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index 612f3266fd79..1aed103f097e 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -2812,6 +2812,12 @@ def call_any(l: Iterable[int]) -> bool: def call_all(l: Iterable[int]) -> bool: return all(i == 0 for i in l) +def call_any_helper(l: list[Iterable[int]]) -> bool: + return any([str(i) for i in l]) + +def call_all_helper(l: list[Iterable[int]]) -> bool: + return all([str(i) for i in l]) + [out] def call_any(l): l :: object @@ -2870,6 +2876,118 @@ L6: L7: L8: return r0 +def call_any_helper(l): + l :: list + r0 :: bool + r1 :: native_int + r2 :: list + r3, r4 :: native_int + r5 :: bit + r6, i :: object + r7 :: str + r8, r9, r10 :: native_int + r11 :: bit + r12 :: object + r13, __mypyc_any_item__ :: str + r14 :: bit + r15 :: native_int +L0: + r0 = 0 + r1 = var_object_size l + r2 = PyList_New(r1) + r3 = 0 +L1: + r4 = var_object_size l + r5 = r3 < r4 :: signed + if r5 goto L2 else goto L4 :: bool +L2: + r6 = list_get_item_unsafe l, r3 + i = r6 + r7 = PyObject_Str(i) + CPyList_SetItemUnsafe(r2, r3, r7) +L3: + r8 = r3 + 1 + r3 = r8 + goto L1 +L4: + r9 = 0 +L5: + r10 = var_object_size r2 + r11 = r9 < r10 :: signed + if r11 goto L6 else goto L10 :: bool +L6: + r12 = list_get_item_unsafe r2, r9 + r13 = cast(str, r12) + __mypyc_any_item__ = r13 + r14 = CPyStr_IsTrue(__mypyc_any_item__) + if r14 goto L7 else goto L8 :: bool +L7: + r0 = 1 + goto L11 +L8: +L9: + r15 = r9 + 1 + r9 = r15 + goto L5 +L10: +L11: + return r0 +def call_all_helper(l): + l :: list + r0 :: bool + r1 :: native_int + r2 :: list + r3, r4 :: native_int + r5 :: bit + r6, i :: object + r7 :: str + r8, r9, r10 :: native_int + r11 :: bit + r12 :: object + r13, __mypyc_all_item__ :: str + r14 :: bit + r15 :: native_int +L0: + r0 = 1 + r1 = var_object_size l + r2 = PyList_New(r1) + r3 = 0 +L1: + r4 = var_object_size l + r5 = r3 < r4 :: signed + if r5 goto L2 else goto L4 :: bool +L2: + r6 = list_get_item_unsafe l, r3 + i = r6 + r7 = PyObject_Str(i) + CPyList_SetItemUnsafe(r2, r3, r7) +L3: + r8 = r3 + 1 + r3 = r8 + goto L1 +L4: + r9 = 0 +L5: + r10 = var_object_size r2 + r11 = r9 < r10 :: signed + if r11 goto L6 else goto L10 :: bool +L6: + r12 = list_get_item_unsafe r2, r9 + r13 = cast(str, r12) + __mypyc_all_item__ = r13 + r14 = CPyStr_IsTrue(__mypyc_all_item__) + if r14 goto L8 else goto L7 :: bool +L7: + r0 = 0 + goto L11 +L8: +L9: + r15 = r9 + 1 + r9 = r15 + goto L5 +L10: +L11: + return r0 [case testSum] from typing import Callable, Iterable @@ -2877,6 +2995,12 @@ from typing import Callable, Iterable def call_sum(l: Iterable[int], comparison: Callable[[int], bool]) -> int: return sum(comparison(x) for x in l) +def call_sum_helper(l: Iterable[int], comparison: Callable[[int], bool]): + return sum([comparison(x) for x in l]) + +def call_sum_helper_start(l: Iterable[int]): + return sum([str(i) for i in l], "") + [out] def call_sum(l, comparison): l, comparison :: object @@ -2915,6 +3039,121 @@ L4: r12 = CPy_NoErrOccurred() L5: return r0 +def call_sum_helper(l, comparison): + l, comparison :: object + r0 :: int + r1 :: list + r2, r3 :: object + r4, x :: int + r5 :: object + r6 :: object[1] + r7 :: object_ptr + r8 :: object + r9 :: bool + r10 :: object + r11 :: i32 + r12, r13 :: bit + r14, r15 :: native_int + r16 :: bit + r17 :: object + r18, __mypyc_sum_item__, r19 :: bool + r20, r21 :: int + r22 :: native_int + r23 :: object +L0: + r0 = 0 + r1 = PyList_New(0) + r2 = PyObject_GetIter(l) +L1: + r3 = PyIter_Next(r2) + if is_error(r3) goto L4 else goto L2 +L2: + r4 = unbox(int, r3) + x = r4 + r5 = box(int, x) + r6 = [r5] + r7 = load_address r6 + r8 = PyObject_Vectorcall(comparison, r7, 1, 0) + keep_alive r5 + r9 = unbox(bool, r8) + r10 = box(bool, r9) + r11 = PyList_Append(r1, r10) + r12 = r11 >= 0 :: signed +L3: + goto L1 +L4: + r13 = CPy_NoErrOccurred() +L5: + r14 = 0 +L6: + r15 = var_object_size r1 + r16 = r14 < r15 :: signed + if r16 goto L7 else goto L9 :: bool +L7: + r17 = list_get_item_unsafe r1, r14 + r18 = unbox(bool, r17) + __mypyc_sum_item__ = r18 + r19 = __mypyc_sum_item__ << 1 + r20 = extend r19: builtins.bool to builtins.int + r21 = CPyTagged_Add(r0, r20) + r0 = r21 +L8: + r22 = r14 + 1 + r14 = r22 + goto L6 +L9: + r23 = box(int, r0) + return r23 +def call_sum_helper_start(l): + l :: object + r0, r1 :: str + r2 :: list + r3, r4 :: object + r5, i :: int + r6 :: str + r7 :: i32 + r8, r9 :: bit + r10, r11 :: native_int + r12 :: bit + r13 :: object + r14, __mypyc_sum_item__, r15 :: str + r16 :: native_int +L0: + r0 = '' + r1 = r0 + r2 = PyList_New(0) + r3 = PyObject_GetIter(l) +L1: + r4 = PyIter_Next(r3) + if is_error(r4) goto L4 else goto L2 +L2: + r5 = unbox(int, r4) + i = r5 + r6 = CPyTagged_Str(i) + r7 = PyList_Append(r2, r6) + r8 = r7 >= 0 :: signed +L3: + goto L1 +L4: + r9 = CPy_NoErrOccurred() +L5: + r10 = 0 +L6: + r11 = var_object_size r2 + r12 = r10 < r11 :: signed + if r12 goto L7 else goto L9 :: bool +L7: + r13 = list_get_item_unsafe r2, r10 + r14 = cast(str, r13) + __mypyc_sum_item__ = r14 + r15 = PyUnicode_Concat(r1, __mypyc_sum_item__) + r1 = r15 +L8: + r16 = r10 + 1 + r10 = r16 + goto L6 +L9: + return r1 [case testSetAttr1] from typing import Any, Dict, List diff --git a/mypyc/test-data/run-misc.test b/mypyc/test-data/run-misc.test index 1074906357ee..c12b6f1e0a64 100644 --- a/mypyc/test-data/run-misc.test +++ b/mypyc/test-data/run-misc.test @@ -794,8 +794,18 @@ def call_all(l: Iterable[int], val: int = 0) -> int: res = all(i == val for i in l) return 0 if res else 1 +def call_any_with_for_helper(l: Iterable[int], val: int = 0) -> int: + # this listcomp isnt a reasonable real world use but proves the any for loop specializer is good + res = any([i == val for i in l]) + return 0 if res else 1 + +def call_all_with_for_helper(l: Iterable[int], val: int = 0) -> int: + # this listcomp isnt a reasonable real world use but proves the all for loop specializer is good + res = all([i == val for i in l]) + return 0 if res else 1 + [file driver.py] -from native import call_any, call_all, call_any_nested +from native import call_any, call_all, call_any_nested, call_any_with_for_helper, call_all_with_for_helper zeros = [0, 0, 0] ones = [1, 1, 1] @@ -807,24 +817,42 @@ mixed_101 = [1, 0, 1] mixed_110 = [1, 1, 0] assert call_any([]) == 1 +assert call_any_with_for_helper([]) == 1 assert call_any(zeros) == 0 +assert call_any_with_for_helper(zeros) == 0 assert call_any(ones) == 1 +assert call_any_with_for_helper(ones) == 1 assert call_any(mixed_001) == 0 +assert call_any_with_for_helper(mixed_001) == 0 assert call_any(mixed_010) == 0 +assert call_any_with_for_helper(mixed_010) == 0 assert call_any(mixed_100) == 0 +assert call_any_with_for_helper(mixed_100) == 0 assert call_any(mixed_011) == 0 +assert call_any_with_for_helper(mixed_011) == 0 assert call_any(mixed_101) == 0 +assert call_any_with_for_helper(mixed_101) == 0 assert call_any(mixed_110) == 0 +assert call_any_with_for_helper(mixed_110) == 0 assert call_all([]) == 0 +assert call_all_with_for_helper([]) == 0 assert call_all(zeros) == 0 +assert call_all_with_for_helper(zeros) == 0 assert call_all(ones) == 1 +assert call_all_with_for_helper(ones) == 1 assert call_all(mixed_001) == 1 +assert call_all_with_for_helper(mixed_001) == 1 assert call_all(mixed_010) == 1 +assert call_all_with_for_helper(mixed_010) == 1 assert call_all(mixed_100) == 1 +assert call_all_with_for_helper(mixed_100) == 1 assert call_all(mixed_011) == 1 +assert call_all_with_for_helper(mixed_011) == 1 assert call_all(mixed_101) == 1 +assert call_all_with_for_helper(mixed_101) == 1 assert call_all(mixed_110) == 1 +assert call_all_with_for_helper(mixed_110) == 1 assert call_any_nested([[1, 1, 1], [1, 1], []]) == 1 assert call_any_nested([[1, 1, 1], [0, 1], []]) == 0