Skip to content

Commit e4513ba

Browse files
committed
[mypyc] feat: specialize any and all using for loop helpers if possible
1 parent 139071c commit e4513ba

File tree

4 files changed

+249
-20
lines changed

4 files changed

+249
-20
lines changed

mypyc/irbuild/for_helpers.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from mypy.nodes import (
1313
ARG_POS,
14+
LDEF,
1415
BytesExpr,
1516
CallExpr,
1617
DictionaryComprehension,
@@ -28,7 +29,7 @@
2829
TypeAlias,
2930
Var,
3031
)
31-
from mypy.types import LiteralType, TupleType, get_proper_type, get_proper_types
32+
from mypy.types import LiteralType, TupleType, Type, get_proper_type, get_proper_types
3233
from mypyc.ir.ops import (
3334
ERR_NEVER,
3435
BasicBlock,
@@ -1241,3 +1242,25 @@ def get_expr_length_value(
12411242
return builder.builder.builtin_len(expr_reg, line, use_pyssize_t=use_pyssize_t)
12421243
# The expression result is known at compile time, so we can use a constant.
12431244
return Integer(length, c_pyssize_t_rprimitive if use_pyssize_t else short_int_rprimitive)
1245+
1246+
1247+
def _is_supported_forloop_iter(builder: IRBuilder, expr: Expression) -> bool:
1248+
if is_sequence_rprimitive(builder.node_type(expr)):
1249+
return True
1250+
return (
1251+
isinstance(expr, CallExpr)
1252+
and (
1253+
(isinstance(expr.callee, RefExpr) and expr.callee.fullname in {
1254+
"builtins.range", "builtins.enumerate", "builtins.zip", "builtins.reversed"
1255+
})
1256+
or (isinstance(expr.callee, MemberExpr) and expr.callee.name in {"keys", "values", "items"})
1257+
)
1258+
)
1259+
1260+
1261+
def _create_iterable_lexpr(index_name: str, index_type: Type) -> NameExpr:
1262+
"""This helper spoofs a NameExpr to use as the lvalue in one of the for loop helpers."""
1263+
index = NameExpr(index_name)
1264+
index.kind = LDEF
1265+
index.node = Var(index_name, index_type)
1266+
return index

mypyc/irbuild/specialize.py

Lines changed: 78 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,10 @@
7979
from mypyc.irbuild.builder import IRBuilder
8080
from mypyc.irbuild.constant_fold import constant_fold_expr
8181
from mypyc.irbuild.for_helpers import (
82+
_is_supported_forloop_iter,
83+
_create_iterable_lexpr,
8284
comprehension_helper,
85+
for_loop_helper,
8386
sequence_from_generator_preallocate_helper,
8487
translate_list_comprehension,
8588
translate_set_comprehension,
@@ -412,29 +415,86 @@ def translate_safe_generator_call(
412415

413416
@specialize_function("builtins.any")
414417
def translate_any_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
415-
if (
416-
len(expr.args) == 1
417-
and expr.arg_kinds == [ARG_POS]
418-
and isinstance(expr.args[0], GeneratorExpr)
419-
):
420-
return any_all_helper(builder, expr.args[0], builder.false, lambda x: x, builder.true)
418+
if len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]:
419+
arg = expr.args[0]
420+
if isinstance(arg, GeneratorExpr):
421+
return any_all_helper(builder, arg, builder.false, lambda x: x, builder.true)
422+
elif _is_supported_forloop_iter(builder, arg):
423+
retval = Register(bool_rprimitive)
424+
builder.assign(retval, builder.false(), -1)
425+
loop_exit = BasicBlock()
426+
index_name = "__mypyc_any_item__"
427+
428+
def body_insts() -> None:
429+
true_block = BasicBlock()
430+
false_block = BasicBlock()
431+
builder.add_bool_branch(builder.read(index_reg), true_block, false_block)
432+
builder.activate_block(true_block)
433+
builder.assign(retval, builder.true(), -1)
434+
builder.goto(loop_exit)
435+
builder.activate_block(false_block)
436+
437+
index_type = builder._analyze_iterable_item_type(arg)
438+
index = _create_iterable_lexpr(index_name, index_type)
439+
index_reg = builder.add_local_reg(index.node, builder.type_to_rtype(index_type))
440+
441+
for_loop_helper(
442+
builder,
443+
index,
444+
arg,
445+
body_insts,
446+
None,
447+
is_async=False,
448+
line=expr.line,
449+
)
450+
builder.goto_and_activate(loop_exit)
451+
return retval
421452
return None
422453

423454

424455
@specialize_function("builtins.all")
425456
def translate_all_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
426-
if (
427-
len(expr.args) == 1
428-
and expr.arg_kinds == [ARG_POS]
429-
and isinstance(expr.args[0], GeneratorExpr)
430-
):
431-
return any_all_helper(
432-
builder,
433-
expr.args[0],
434-
builder.true,
435-
lambda x: builder.unary_op(x, "not", expr.line),
436-
builder.false,
437-
)
457+
if len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]:
458+
arg = expr.args[0]
459+
if isinstance(arg, GeneratorExpr):
460+
return any_all_helper(
461+
builder,
462+
arg,
463+
builder.true,
464+
lambda x: builder.unary_op(x, "not", expr.line),
465+
builder.false,
466+
)
467+
468+
elif _is_supported_forloop_iter(builder, arg):
469+
retval = Register(bool_rprimitive)
470+
builder.assign(retval, builder.true(), -1)
471+
loop_exit = BasicBlock()
472+
index_name = "__mypyc_all_item__"
473+
474+
def body_insts() -> None:
475+
true_block = BasicBlock()
476+
false_block = BasicBlock()
477+
builder.add_bool_branch(builder.read(index_reg), true_block, false_block)
478+
builder.activate_block(false_block)
479+
builder.assign(retval, builder.false(), -1)
480+
builder.goto(loop_exit)
481+
builder.activate_block(true_block)
482+
483+
index_type = builder._analyze_iterable_item_type(arg)
484+
index = _create_iterable_lexpr(index_name, index_type)
485+
index_reg = builder.add_local_reg(index.node, builder.type_to_rtype(index_type))
486+
487+
for_loop_helper(
488+
builder,
489+
index,
490+
arg,
491+
body_insts,
492+
None,
493+
is_async=False,
494+
line=expr.line,
495+
)
496+
builder.goto_and_activate(loop_exit)
497+
return retval
438498
return None
439499

440500

mypyc/test-data/irbuild-basic.test

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2812,6 +2812,12 @@ def call_any(l: Iterable[int]) -> bool:
28122812
def call_all(l: Iterable[int]) -> bool:
28132813
return all(i == 0 for i in l)
28142814

2815+
def call_any_helper(l: list[Iterable[int]]) -> bool:
2816+
return any([str(i) for i in l])
2817+
2818+
def call_all_helper(l: list[Iterable[int]]) -> bool:
2819+
return all([str(i) for i in l])
2820+
28152821
[out]
28162822
def call_any(l):
28172823
l :: object
@@ -2870,6 +2876,118 @@ L6:
28702876
L7:
28712877
L8:
28722878
return r0
2879+
def call_any_helper(l):
2880+
l :: list
2881+
r0 :: bool
2882+
r1 :: native_int
2883+
r2 :: list
2884+
r3, r4 :: native_int
2885+
r5 :: bit
2886+
r6, i :: object
2887+
r7 :: str
2888+
r8, r9, r10 :: native_int
2889+
r11 :: bit
2890+
r12 :: object
2891+
r13, __mypyc_any_item__ :: str
2892+
r14 :: bit
2893+
r15 :: native_int
2894+
L0:
2895+
r0 = 0
2896+
r1 = var_object_size l
2897+
r2 = PyList_New(r1)
2898+
r3 = 0
2899+
L1:
2900+
r4 = var_object_size l
2901+
r5 = r3 < r4 :: signed
2902+
if r5 goto L2 else goto L4 :: bool
2903+
L2:
2904+
r6 = list_get_item_unsafe l, r3
2905+
i = r6
2906+
r7 = PyObject_Str(i)
2907+
CPyList_SetItemUnsafe(r2, r3, r7)
2908+
L3:
2909+
r8 = r3 + 1
2910+
r3 = r8
2911+
goto L1
2912+
L4:
2913+
r9 = 0
2914+
L5:
2915+
r10 = var_object_size r2
2916+
r11 = r9 < r10 :: signed
2917+
if r11 goto L6 else goto L10 :: bool
2918+
L6:
2919+
r12 = list_get_item_unsafe r2, r9
2920+
r13 = cast(str, r12)
2921+
__mypyc_any_item__ = r13
2922+
r14 = CPyStr_IsTrue(__mypyc_any_item__)
2923+
if r14 goto L7 else goto L8 :: bool
2924+
L7:
2925+
r0 = 1
2926+
goto L11
2927+
L8:
2928+
L9:
2929+
r15 = r9 + 1
2930+
r9 = r15
2931+
goto L5
2932+
L10:
2933+
L11:
2934+
return r0
2935+
def call_all_helper(l):
2936+
l :: list
2937+
r0 :: bool
2938+
r1 :: native_int
2939+
r2 :: list
2940+
r3, r4 :: native_int
2941+
r5 :: bit
2942+
r6, i :: object
2943+
r7 :: str
2944+
r8, r9, r10 :: native_int
2945+
r11 :: bit
2946+
r12 :: object
2947+
r13, __mypyc_all_item__ :: str
2948+
r14 :: bit
2949+
r15 :: native_int
2950+
L0:
2951+
r0 = 1
2952+
r1 = var_object_size l
2953+
r2 = PyList_New(r1)
2954+
r3 = 0
2955+
L1:
2956+
r4 = var_object_size l
2957+
r5 = r3 < r4 :: signed
2958+
if r5 goto L2 else goto L4 :: bool
2959+
L2:
2960+
r6 = list_get_item_unsafe l, r3
2961+
i = r6
2962+
r7 = PyObject_Str(i)
2963+
CPyList_SetItemUnsafe(r2, r3, r7)
2964+
L3:
2965+
r8 = r3 + 1
2966+
r3 = r8
2967+
goto L1
2968+
L4:
2969+
r9 = 0
2970+
L5:
2971+
r10 = var_object_size r2
2972+
r11 = r9 < r10 :: signed
2973+
if r11 goto L6 else goto L10 :: bool
2974+
L6:
2975+
r12 = list_get_item_unsafe r2, r9
2976+
r13 = cast(str, r12)
2977+
__mypyc_all_item__ = r13
2978+
r14 = CPyStr_IsTrue(__mypyc_all_item__)
2979+
if r14 goto L8 else goto L7 :: bool
2980+
L7:
2981+
r0 = 0
2982+
goto L11
2983+
L8:
2984+
L9:
2985+
r15 = r9 + 1
2986+
r9 = r15
2987+
goto L5
2988+
L10:
2989+
L11:
2990+
return r0
28732991

28742992
[case testSum]
28752993
from typing import Callable, Iterable

mypyc/test-data/run-misc.test

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -794,8 +794,18 @@ def call_all(l: Iterable[int], val: int = 0) -> int:
794794
res = all(i == val for i in l)
795795
return 0 if res else 1
796796

797+
def call_any_with_for_helper(l: Iterable[int], val: int = 0) -> int:
798+
# this listcomp isnt a reasonable real world use but proves the any for loop specializer is good
799+
res = any([i == val for i in l])
800+
return 0 if res else 1
801+
802+
def call_all_with_for_helper(l: Iterable[int], val: int = 0) -> int:
803+
# this listcomp isnt a reasonable real world use but proves the all for loop specializer is good
804+
res = all([i == val for i in l])
805+
return 0 if res else 1
806+
797807
[file driver.py]
798-
from native import call_any, call_all, call_any_nested
808+
from native import call_any, call_all, call_any_nested, call_any_with_for_helper, call_all_with_for_helper
799809

800810
zeros = [0, 0, 0]
801811
ones = [1, 1, 1]
@@ -807,24 +817,42 @@ mixed_101 = [1, 0, 1]
807817
mixed_110 = [1, 1, 0]
808818

809819
assert call_any([]) == 1
820+
assert call_any_with_for_helper([]) == 1
810821
assert call_any(zeros) == 0
822+
assert call_any_with_for_helper(zeros) == 0
811823
assert call_any(ones) == 1
824+
assert call_any_with_for_helper(ones) == 1
812825
assert call_any(mixed_001) == 0
826+
assert call_any_with_for_helper(mixed_001) == 0
813827
assert call_any(mixed_010) == 0
828+
assert call_any_with_for_helper(mixed_010) == 0
814829
assert call_any(mixed_100) == 0
830+
assert call_any_with_for_helper(mixed_100) == 0
815831
assert call_any(mixed_011) == 0
832+
assert call_any_with_for_helper(mixed_011) == 0
816833
assert call_any(mixed_101) == 0
834+
assert call_any_with_for_helper(mixed_101) == 0
817835
assert call_any(mixed_110) == 0
836+
assert call_any_with_for_helper(mixed_110) == 0
818837

819838
assert call_all([]) == 0
839+
assert call_all_with_for_helper([]) == 0
820840
assert call_all(zeros) == 0
841+
assert call_all_with_for_helper(zeros) == 0
821842
assert call_all(ones) == 1
843+
assert call_all_with_for_helper(ones) == 1
822844
assert call_all(mixed_001) == 1
845+
assert call_all_with_for_helper(mixed_001) == 1
823846
assert call_all(mixed_010) == 1
847+
assert call_all_with_for_helper(mixed_010) == 1
824848
assert call_all(mixed_100) == 1
849+
assert call_all_with_for_helper(mixed_100) == 1
825850
assert call_all(mixed_011) == 1
851+
assert call_all_with_for_helper(mixed_011) == 1
826852
assert call_all(mixed_101) == 1
853+
assert call_all_with_for_helper(mixed_101) == 1
827854
assert call_all(mixed_110) == 1
855+
assert call_all_with_for_helper(mixed_110) == 1
828856

829857
assert call_any_nested([[1, 1, 1], [1, 1], []]) == 1
830858
assert call_any_nested([[1, 1, 1], [0, 1], []]) == 0

0 commit comments

Comments
 (0)