Skip to content

Commit 41c7c9b

Browse files
committed
[mypyc] feat: specialize any and all using for loop helpers if possible
1 parent 139071c commit 41c7c9b

File tree

4 files changed

+238
-26
lines changed

4 files changed

+238
-26
lines changed

mypyc/irbuild/for_helpers.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from mypy.nodes import (
1313
ARG_POS,
14+
LDEF,
1415
BytesExpr,
1516
CallExpr,
1617
DictionaryComprehension,
@@ -28,7 +29,7 @@
2829
TypeAlias,
2930
Var,
3031
)
31-
from mypy.types import LiteralType, TupleType, get_proper_type, get_proper_types
32+
from mypy.types import LiteralType, TupleType, Type, get_proper_type, get_proper_types
3233
from mypyc.ir.ops import (
3334
ERR_NEVER,
3435
BasicBlock,
@@ -1241,3 +1242,27 @@ def get_expr_length_value(
12411242
return builder.builder.builtin_len(expr_reg, line, use_pyssize_t=use_pyssize_t)
12421243
# The expression result is known at compile time, so we can use a constant.
12431244
return Integer(length, c_pyssize_t_rprimitive if use_pyssize_t else short_int_rprimitive)
1245+
1246+
1247+
def expr_has_specialized_for_helper(builder: IRBuilder, expr: Expression) -> bool:
1248+
if is_sequence_rprimitive(builder.node_type(expr)):
1249+
return True
1250+
if not isinstance(expr, CallExpr):
1251+
return False
1252+
if isinstance(expr.callee, RefExpr):
1253+
return expr.callee.fullname in {"builtins.range", "builtins.enumerate", "builtins.zip"}
1254+
elif isinstance(expr.callee, MemberExpr):
1255+
return expr.callee.fullname in {
1256+
"builtins.dict.keys",
1257+
"builtins.dict.values",
1258+
"builtins.dict.items",
1259+
}
1260+
return False
1261+
1262+
1263+
def create_synthetic_nameexpr(index_name: str, index_type: Type) -> NameExpr:
1264+
"""This helper spoofs a NameExpr to use as the lvalue in one of the for loop helpers."""
1265+
index = NameExpr(index_name)
1266+
index.kind = LDEF
1267+
index.node = Var(index_name, index_type)
1268+
return index

mypyc/irbuild/specialize.py

Lines changed: 65 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@
8080
from mypyc.irbuild.constant_fold import constant_fold_expr
8181
from mypyc.irbuild.for_helpers import (
8282
comprehension_helper,
83+
create_synthetic_nameexpr,
84+
expr_has_specialized_for_helper,
85+
for_loop_helper,
8386
sequence_from_generator_preallocate_helper,
8487
translate_list_comprehension,
8588
translate_set_comprehension,
@@ -412,29 +415,70 @@ def translate_safe_generator_call(
412415

413416
@specialize_function("builtins.any")
414417
def translate_any_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
415-
if (
416-
len(expr.args) == 1
417-
and expr.arg_kinds == [ARG_POS]
418-
and isinstance(expr.args[0], GeneratorExpr)
419-
):
420-
return any_all_helper(builder, expr.args[0], builder.false, lambda x: x, builder.true)
418+
if len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]:
419+
arg = expr.args[0]
420+
if isinstance(arg, GeneratorExpr):
421+
return any_all_helper(builder, arg, builder.false, lambda x: x, builder.true)
422+
elif expr_has_specialized_for_helper(builder, arg):
423+
retval = Register(bool_rprimitive)
424+
builder.assign(retval, builder.false(), -1)
425+
loop_exit = BasicBlock()
426+
index_name = "__mypyc_any_item__"
427+
428+
def body_insts() -> None:
429+
true_block = BasicBlock()
430+
false_block = BasicBlock()
431+
builder.add_bool_branch(builder.read(index_reg), true_block, false_block)
432+
builder.activate_block(true_block)
433+
builder.assign(retval, builder.true(), -1)
434+
builder.goto(loop_exit)
435+
builder.activate_block(false_block)
436+
437+
index_type = builder._analyze_iterable_item_type(arg)
438+
index = create_synthetic_nameexpr(index_name, index_type)
439+
index_reg = builder.add_local_reg(index.node, builder.type_to_rtype(index_type)) # type: ignore [arg-type]
440+
441+
for_loop_helper(builder, index, arg, body_insts, None, is_async=False, line=expr.line)
442+
builder.goto_and_activate(loop_exit)
443+
return retval
421444
return None
422445

423446

424447
@specialize_function("builtins.all")
425448
def translate_all_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
426-
if (
427-
len(expr.args) == 1
428-
and expr.arg_kinds == [ARG_POS]
429-
and isinstance(expr.args[0], GeneratorExpr)
430-
):
431-
return any_all_helper(
432-
builder,
433-
expr.args[0],
434-
builder.true,
435-
lambda x: builder.unary_op(x, "not", expr.line),
436-
builder.false,
437-
)
449+
if len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]:
450+
arg = expr.args[0]
451+
if isinstance(arg, GeneratorExpr):
452+
return any_all_helper(
453+
builder,
454+
arg,
455+
builder.true,
456+
lambda x: builder.unary_op(x, "not", expr.line),
457+
builder.false,
458+
)
459+
460+
elif expr_has_specialized_for_helper(builder, arg):
461+
retval = Register(bool_rprimitive)
462+
builder.assign(retval, builder.true(), -1)
463+
loop_exit = BasicBlock()
464+
index_name = "__mypyc_all_item__"
465+
466+
def body_insts() -> None:
467+
true_block = BasicBlock()
468+
false_block = BasicBlock()
469+
builder.add_bool_branch(builder.read(index_reg), true_block, false_block)
470+
builder.activate_block(false_block)
471+
builder.assign(retval, builder.false(), -1)
472+
builder.goto(loop_exit)
473+
builder.activate_block(true_block)
474+
475+
index_type = builder._analyze_iterable_item_type(arg)
476+
index = create_synthetic_nameexpr(index_name, index_type)
477+
index_reg = builder.add_local_reg(index.node, builder.type_to_rtype(index_type)) # type: ignore [arg-type]
478+
479+
for_loop_helper(builder, index, arg, body_insts, None, is_async=False, line=expr.line)
480+
builder.goto_and_activate(loop_exit)
481+
return retval
438482
return None
439483

440484

@@ -610,12 +654,9 @@ def translate_isinstance(builder: IRBuilder, expr: CallExpr, callee: RefExpr) ->
610654
return builder.builder.isinstance_helper(obj, irs, expr.line)
611655

612656
if isinstance(type_expr, RefExpr):
613-
node = type_expr.node
614-
if node:
615-
desc = isinstance_primitives.get(node.fullname)
616-
if desc:
617-
obj = builder.accept(obj_expr)
618-
return builder.primitive_op(desc, [obj], expr.line)
657+
if node := type_expr.node:
658+
if desc := isinstance_primitives.get(node.fullname):
659+
return builder.primitive_op(desc, [builder.accept(obj_expr)], expr.line)
619660

620661
elif isinstance(type_expr, TupleExpr):
621662
node_names: list[str] = []

mypyc/test-data/irbuild-basic.test

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2812,6 +2812,12 @@ def call_any(l: Iterable[int]) -> bool:
28122812
def call_all(l: Iterable[int]) -> bool:
28132813
return all(i == 0 for i in l)
28142814

2815+
def call_any_helper(l: list[Iterable[int]]) -> bool:
2816+
return any([str(i) for i in l])
2817+
2818+
def call_all_helper(l: list[Iterable[int]]) -> bool:
2819+
return all([str(i) for i in l])
2820+
28152821
[out]
28162822
def call_any(l):
28172823
l :: object
@@ -2870,6 +2876,118 @@ L6:
28702876
L7:
28712877
L8:
28722878
return r0
2879+
def call_any_helper(l):
2880+
l :: list
2881+
r0 :: bool
2882+
r1 :: native_int
2883+
r2 :: list
2884+
r3, r4 :: native_int
2885+
r5 :: bit
2886+
r6, i :: object
2887+
r7 :: str
2888+
r8, r9, r10 :: native_int
2889+
r11 :: bit
2890+
r12 :: object
2891+
r13, __mypyc_any_item__ :: str
2892+
r14 :: bit
2893+
r15 :: native_int
2894+
L0:
2895+
r0 = 0
2896+
r1 = var_object_size l
2897+
r2 = PyList_New(r1)
2898+
r3 = 0
2899+
L1:
2900+
r4 = var_object_size l
2901+
r5 = r3 < r4 :: signed
2902+
if r5 goto L2 else goto L4 :: bool
2903+
L2:
2904+
r6 = list_get_item_unsafe l, r3
2905+
i = r6
2906+
r7 = PyObject_Str(i)
2907+
CPyList_SetItemUnsafe(r2, r3, r7)
2908+
L3:
2909+
r8 = r3 + 1
2910+
r3 = r8
2911+
goto L1
2912+
L4:
2913+
r9 = 0
2914+
L5:
2915+
r10 = var_object_size r2
2916+
r11 = r9 < r10 :: signed
2917+
if r11 goto L6 else goto L10 :: bool
2918+
L6:
2919+
r12 = list_get_item_unsafe r2, r9
2920+
r13 = cast(str, r12)
2921+
__mypyc_any_item__ = r13
2922+
r14 = CPyStr_IsTrue(__mypyc_any_item__)
2923+
if r14 goto L7 else goto L8 :: bool
2924+
L7:
2925+
r0 = 1
2926+
goto L11
2927+
L8:
2928+
L9:
2929+
r15 = r9 + 1
2930+
r9 = r15
2931+
goto L5
2932+
L10:
2933+
L11:
2934+
return r0
2935+
def call_all_helper(l):
2936+
l :: list
2937+
r0 :: bool
2938+
r1 :: native_int
2939+
r2 :: list
2940+
r3, r4 :: native_int
2941+
r5 :: bit
2942+
r6, i :: object
2943+
r7 :: str
2944+
r8, r9, r10 :: native_int
2945+
r11 :: bit
2946+
r12 :: object
2947+
r13, __mypyc_all_item__ :: str
2948+
r14 :: bit
2949+
r15 :: native_int
2950+
L0:
2951+
r0 = 1
2952+
r1 = var_object_size l
2953+
r2 = PyList_New(r1)
2954+
r3 = 0
2955+
L1:
2956+
r4 = var_object_size l
2957+
r5 = r3 < r4 :: signed
2958+
if r5 goto L2 else goto L4 :: bool
2959+
L2:
2960+
r6 = list_get_item_unsafe l, r3
2961+
i = r6
2962+
r7 = PyObject_Str(i)
2963+
CPyList_SetItemUnsafe(r2, r3, r7)
2964+
L3:
2965+
r8 = r3 + 1
2966+
r3 = r8
2967+
goto L1
2968+
L4:
2969+
r9 = 0
2970+
L5:
2971+
r10 = var_object_size r2
2972+
r11 = r9 < r10 :: signed
2973+
if r11 goto L6 else goto L10 :: bool
2974+
L6:
2975+
r12 = list_get_item_unsafe r2, r9
2976+
r13 = cast(str, r12)
2977+
__mypyc_all_item__ = r13
2978+
r14 = CPyStr_IsTrue(__mypyc_all_item__)
2979+
if r14 goto L8 else goto L7 :: bool
2980+
L7:
2981+
r0 = 0
2982+
goto L11
2983+
L8:
2984+
L9:
2985+
r15 = r9 + 1
2986+
r9 = r15
2987+
goto L5
2988+
L10:
2989+
L11:
2990+
return r0
28732991

28742992
[case testSum]
28752993
from typing import Callable, Iterable

mypyc/test-data/run-misc.test

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -794,8 +794,18 @@ def call_all(l: Iterable[int], val: int = 0) -> int:
794794
res = all(i == val for i in l)
795795
return 0 if res else 1
796796

797+
def call_any_with_for_helper(l: Iterable[int], val: int = 0) -> int:
798+
# this listcomp isnt a reasonable real world use but proves the any for loop specializer is good
799+
res = any([i == val for i in l])
800+
return 0 if res else 1
801+
802+
def call_all_with_for_helper(l: Iterable[int], val: int = 0) -> int:
803+
# this listcomp isnt a reasonable real world use but proves the all for loop specializer is good
804+
res = all([i == val for i in l])
805+
return 0 if res else 1
806+
797807
[file driver.py]
798-
from native import call_any, call_all, call_any_nested
808+
from native import call_any, call_all, call_any_nested, call_any_with_for_helper, call_all_with_for_helper
799809

800810
zeros = [0, 0, 0]
801811
ones = [1, 1, 1]
@@ -807,24 +817,42 @@ mixed_101 = [1, 0, 1]
807817
mixed_110 = [1, 1, 0]
808818

809819
assert call_any([]) == 1
820+
assert call_any_with_for_helper([]) == 1
810821
assert call_any(zeros) == 0
822+
assert call_any_with_for_helper(zeros) == 0
811823
assert call_any(ones) == 1
824+
assert call_any_with_for_helper(ones) == 1
812825
assert call_any(mixed_001) == 0
826+
assert call_any_with_for_helper(mixed_001) == 0
813827
assert call_any(mixed_010) == 0
828+
assert call_any_with_for_helper(mixed_010) == 0
814829
assert call_any(mixed_100) == 0
830+
assert call_any_with_for_helper(mixed_100) == 0
815831
assert call_any(mixed_011) == 0
832+
assert call_any_with_for_helper(mixed_011) == 0
816833
assert call_any(mixed_101) == 0
834+
assert call_any_with_for_helper(mixed_101) == 0
817835
assert call_any(mixed_110) == 0
836+
assert call_any_with_for_helper(mixed_110) == 0
818837

819838
assert call_all([]) == 0
839+
assert call_all_with_for_helper([]) == 0
820840
assert call_all(zeros) == 0
841+
assert call_all_with_for_helper(zeros) == 0
821842
assert call_all(ones) == 1
843+
assert call_all_with_for_helper(ones) == 1
822844
assert call_all(mixed_001) == 1
845+
assert call_all_with_for_helper(mixed_001) == 1
823846
assert call_all(mixed_010) == 1
847+
assert call_all_with_for_helper(mixed_010) == 1
824848
assert call_all(mixed_100) == 1
849+
assert call_all_with_for_helper(mixed_100) == 1
825850
assert call_all(mixed_011) == 1
851+
assert call_all_with_for_helper(mixed_011) == 1
826852
assert call_all(mixed_101) == 1
853+
assert call_all_with_for_helper(mixed_101) == 1
827854
assert call_all(mixed_110) == 1
855+
assert call_all_with_for_helper(mixed_110) == 1
828856

829857
assert call_any_nested([[1, 1, 1], [1, 1], []]) == 1
830858
assert call_any_nested([[1, 1, 1], [0, 1], []]) == 0

0 commit comments

Comments
 (0)