Skip to content

Commit c1def27

Browse files
committed
specialize sum from for helper
1 parent 7d98f34 commit c1def27

File tree

2 files changed

+179
-17
lines changed

2 files changed

+179
-17
lines changed

mypyc/irbuild/specialize.py

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
bytes_rprimitive,
5656
c_int_rprimitive,
5757
dict_rprimitive,
58+
float_rprimitive,
5859
int16_rprimitive,
5960
int32_rprimitive,
6061
int64_rprimitive,
@@ -68,6 +69,7 @@
6869
is_int64_rprimitive,
6970
is_int_rprimitive,
7071
is_list_rprimitive,
72+
is_object_rprimitive,
7173
is_uint8_rprimitive,
7274
list_rprimitive,
7375
object_rprimitive,
@@ -511,11 +513,11 @@ def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> V
511513
# - only one or two arguments given (if not, sum() has been given invalid arguments)
512514
# - first argument is a Generator (there is no benefit to optimizing the performance of e.g.
513515
# sum([1, 2, 3]), so non-Generator Iterables are not handled)
514-
if not (
515-
len(expr.args) in (1, 2)
516-
and expr.arg_kinds[0] == ARG_POS
517-
and isinstance(expr.args[0], GeneratorExpr)
518-
):
516+
if not (len(expr.args) in (1, 2) and expr.arg_kinds[0] == ARG_POS):
517+
return None
518+
519+
arg = expr.args[0]
520+
if not isinstance(arg, GeneratorExpr) and not _is_supported_forloop_iter(builder, arg):
519521
return None
520522

521523
# handle 'start' argument, if given
@@ -527,21 +529,60 @@ def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> V
527529
else:
528530
start_expr = IntExpr(0)
529531

530-
gen_expr = expr.args[0]
531-
target_type = builder.node_type(expr)
532-
retval = Register(target_type)
533-
builder.assign(retval, builder.coerce(builder.accept(start_expr), target_type, -1), -1)
532+
item_type = builder._analyze_iterable_item_type(arg)
533+
item_rtype = builder.type_to_rtype(item_type)
534+
start_rtype = builder.node_type(start_expr)
535+
536+
if item_rtype is start_rtype:
537+
acc_rtype = item_rtype
538+
elif is_float_rprimitive(item_rtype) and is_int_rprimitive(start_rtype):
539+
acc_rtype = float_rprimitive
540+
elif is_bool_rprimitive(item_rtype) and is_int_rprimitive(start_rtype):
541+
acc_rtype = int_rprimitive
542+
elif is_object_rprimitive(item_rtype) and is_int_rprimitive(start_rtype):
543+
acc_rtype = object_rprimitive
544+
545+
else:
546+
# Escape hatch: no accumulator type is chosen for this item/start type combination, so
547+
# return None instead of specializing. This block could use a cleaner design; ideas welcome in review.
548+
return None
534549

535-
def gen_inner_stmts() -> None:
536-
call_expr = builder.accept(gen_expr.left_expr)
537-
builder.assign(retval, builder.binary_op(retval, call_expr, "+", -1), -1)
550+
retval = Register(acc_rtype)
551+
builder.assign(retval, builder.coerce(builder.accept(start_expr), acc_rtype, -1), -1)
538552

539-
loop_params = list(
540-
zip(gen_expr.indices, gen_expr.sequences, gen_expr.condlists, gen_expr.is_async)
541-
)
542-
comprehension_helper(builder, loop_params, gen_inner_stmts, gen_expr.line)
553+
if isinstance(arg, GeneratorExpr):
554+
def gen_inner_stmts() -> None:
555+
call_expr = builder.accept(arg.left_expr)
556+
builder.assign(retval, builder.binary_op(retval, call_expr, "+", -1), -1)
543557

544-
return retval
558+
loop_params = list(
559+
zip(arg.indices, arg.sequences, arg.condlists, arg.is_async)
560+
)
561+
comprehension_helper(builder, loop_params, gen_inner_stmts, arg.line)
562+
563+
return retval
564+
565+
else:
566+
index_name = "__mypyc_sum_item__"
567+
568+
def body_insts() -> None:
569+
total = builder.binary_op(retval, builder.read(index_reg), "+", expr.line)
570+
builder.assign(retval, total, expr.line)
571+
572+
index_type = builder._analyze_iterable_item_type(arg)
573+
index = _create_iterable_lexpr(index_name, index_type)
574+
index_reg = builder.add_local_reg(index.node, builder.type_to_rtype(index_type))
575+
576+
for_loop_helper(
577+
builder,
578+
index,
579+
arg,
580+
body_insts,
581+
None,
582+
is_async=False,
583+
line=expr.line,
584+
)
585+
return retval
545586

546587

547588
@specialize_function("dataclasses.field")

mypyc/test-data/irbuild-basic.test

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2995,6 +2995,12 @@ from typing import Callable, Iterable
29952995
def call_sum(l: Iterable[int], comparison: Callable[[int], bool]) -> int:
29962996
return sum(comparison(x) for x in l)
29972997

2998+
def call_sum_helper(l: Iterable[int], comparison: Callable[[int], bool]):
2999+
return sum([comparison(x) for x in l])
3000+
3001+
def call_sum_helper_start(l: Iterable[int]):
3002+
return sum([str(i) for i in l], "")
3003+
29983004
[out]
29993005
def call_sum(l, comparison):
30003006
l, comparison :: object
@@ -3033,6 +3039,121 @@ L4:
30333039
r12 = CPy_NoErrOccurred()
30343040
L5:
30353041
return r0
3042+
def call_sum_helper(l, comparison):
3043+
l, comparison :: object
3044+
r0 :: int
3045+
r1 :: list
3046+
r2, r3 :: object
3047+
r4, x :: int
3048+
r5 :: object
3049+
r6 :: object[1]
3050+
r7 :: object_ptr
3051+
r8 :: object
3052+
r9 :: bool
3053+
r10 :: object
3054+
r11 :: i32
3055+
r12, r13 :: bit
3056+
r14, r15 :: native_int
3057+
r16 :: bit
3058+
r17 :: object
3059+
r18, __mypyc_sum_item__, r19 :: bool
3060+
r20, r21 :: int
3061+
r22 :: native_int
3062+
r23 :: object
3063+
L0:
3064+
r0 = 0
3065+
r1 = PyList_New(0)
3066+
r2 = PyObject_GetIter(l)
3067+
L1:
3068+
r3 = PyIter_Next(r2)
3069+
if is_error(r3) goto L4 else goto L2
3070+
L2:
3071+
r4 = unbox(int, r3)
3072+
x = r4
3073+
r5 = box(int, x)
3074+
r6 = [r5]
3075+
r7 = load_address r6
3076+
r8 = PyObject_Vectorcall(comparison, r7, 1, 0)
3077+
keep_alive r5
3078+
r9 = unbox(bool, r8)
3079+
r10 = box(bool, r9)
3080+
r11 = PyList_Append(r1, r10)
3081+
r12 = r11 >= 0 :: signed
3082+
L3:
3083+
goto L1
3084+
L4:
3085+
r13 = CPy_NoErrOccurred()
3086+
L5:
3087+
r14 = 0
3088+
L6:
3089+
r15 = var_object_size r1
3090+
r16 = r14 < r15 :: signed
3091+
if r16 goto L7 else goto L9 :: bool
3092+
L7:
3093+
r17 = list_get_item_unsafe r1, r14
3094+
r18 = unbox(bool, r17)
3095+
__mypyc_sum_item__ = r18
3096+
r19 = __mypyc_sum_item__ << 1
3097+
r20 = extend r19: builtins.bool to builtins.int
3098+
r21 = CPyTagged_Add(r0, r20)
3099+
r0 = r21
3100+
L8:
3101+
r22 = r14 + 1
3102+
r14 = r22
3103+
goto L6
3104+
L9:
3105+
r23 = box(int, r0)
3106+
return r23
3107+
def call_sum_helper_start(l):
3108+
l :: object
3109+
r0, r1 :: str
3110+
r2 :: list
3111+
r3, r4 :: object
3112+
r5, i :: int
3113+
r6 :: str
3114+
r7 :: i32
3115+
r8, r9 :: bit
3116+
r10, r11 :: native_int
3117+
r12 :: bit
3118+
r13 :: object
3119+
r14, __mypyc_sum_item__, r15 :: str
3120+
r16 :: native_int
3121+
L0:
3122+
r0 = ''
3123+
r1 = r0
3124+
r2 = PyList_New(0)
3125+
r3 = PyObject_GetIter(l)
3126+
L1:
3127+
r4 = PyIter_Next(r3)
3128+
if is_error(r4) goto L4 else goto L2
3129+
L2:
3130+
r5 = unbox(int, r4)
3131+
i = r5
3132+
r6 = CPyTagged_Str(i)
3133+
r7 = PyList_Append(r2, r6)
3134+
r8 = r7 >= 0 :: signed
3135+
L3:
3136+
goto L1
3137+
L4:
3138+
r9 = CPy_NoErrOccurred()
3139+
L5:
3140+
r10 = 0
3141+
L6:
3142+
r11 = var_object_size r2
3143+
r12 = r10 < r11 :: signed
3144+
if r12 goto L7 else goto L9 :: bool
3145+
L7:
3146+
r13 = list_get_item_unsafe r2, r10
3147+
r14 = cast(str, r13)
3148+
__mypyc_sum_item__ = r14
3149+
r15 = PyUnicode_Concat(r1, __mypyc_sum_item__)
3150+
r1 = r15
3151+
L8:
3152+
r16 = r10 + 1
3153+
r10 = r16
3154+
goto L6
3155+
L9:
3156+
return r1
30363157

30373158
[case testSetAttr1]
30383159
from typing import Any, Dict, List

0 commit comments

Comments
 (0)