Skip to content

Commit a866af7

Browse files
committed
specialize sum from for helper
1 parent 92818b2 commit a866af7

File tree

2 files changed

+179
-17
lines changed

2 files changed

+179
-17
lines changed

mypyc/irbuild/specialize.py

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
bytes_rprimitive,
5858
c_int_rprimitive,
5959
dict_rprimitive,
60+
float_rprimitive,
6061
int16_rprimitive,
6162
int32_rprimitive,
6263
int64_rprimitive,
@@ -70,6 +71,7 @@
7071
is_int64_rprimitive,
7172
is_int_rprimitive,
7273
is_list_rprimitive,
74+
is_object_rprimitive,
7375
is_uint8_rprimitive,
7476
list_rprimitive,
7577
object_rprimitive,
@@ -515,11 +517,11 @@ def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> V
515517
# - only one or two arguments given (if not, sum() has been given invalid arguments)
516518
# - first argument is a Generator (there is no benefit to optimizing the performance of eg.
517519
# sum([1, 2, 3]), so non-Generator Iterables are not handled)
518-
if not (
519-
len(expr.args) in (1, 2)
520-
and expr.arg_kinds[0] == ARG_POS
521-
and isinstance(expr.args[0], GeneratorExpr)
522-
):
520+
if not (len(expr.args) in (1, 2) and expr.arg_kinds[0] == ARG_POS):
521+
return None
522+
523+
arg = expr.args[0]
524+
if not isinstance(arg, GeneratorExpr) and not _is_supported_forloop_iter(builder, arg):
523525
return None
524526

525527
# handle 'start' argument, if given
@@ -531,21 +533,60 @@ def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> V
531533
else:
532534
start_expr = IntExpr(0)
533535

534-
gen_expr = expr.args[0]
535-
target_type = builder.node_type(expr)
536-
retval = Register(target_type)
537-
builder.assign(retval, builder.coerce(builder.accept(start_expr), target_type, -1), -1)
536+
item_type = builder._analyze_iterable_item_type(arg)
537+
item_rtype = builder.type_to_rtype(item_type)
538+
start_rtype = builder.node_type(start_expr)
539+
540+
if item_rtype is start_rtype:
541+
acc_rtype = item_rtype
542+
elif is_float_rprimitive(item_rtype) and is_int_rprimitive(start_rtype):
543+
acc_rtype = float_rprimitive
544+
elif is_bool_rprimitive(item_rtype) and is_int_rprimitive(start_rtype):
545+
acc_rtype = int_rprimitive
546+
elif is_object_rprimitive(item_rtype) and is_int_rprimitive(start_rtype):
547+
acc_rtype = object_rprimitive
548+
549+
else:
550+
# escape hatch, maybe figure out a better way to handle this whole block
551+
# seeking ideas in review
552+
return None
538553

539-
def gen_inner_stmts() -> None:
540-
call_expr = builder.accept(gen_expr.left_expr)
541-
builder.assign(retval, builder.binary_op(retval, call_expr, "+", -1), -1)
554+
retval = Register(acc_rtype)
555+
builder.assign(retval, builder.coerce(builder.accept(start_expr), acc_rtype, -1), -1)
542556

543-
loop_params = list(
544-
zip(gen_expr.indices, gen_expr.sequences, gen_expr.condlists, gen_expr.is_async)
545-
)
546-
comprehension_helper(builder, loop_params, gen_inner_stmts, gen_expr.line)
557+
if isinstance(arg, GeneratorExpr):
558+
def gen_inner_stmts() -> None:
559+
call_expr = builder.accept(arg.left_expr)
560+
builder.assign(retval, builder.binary_op(retval, call_expr, "+", -1), -1)
547561

548-
return retval
562+
loop_params = list(
563+
zip(arg.indices, arg.sequences, arg.condlists, arg.is_async)
564+
)
565+
comprehension_helper(builder, loop_params, gen_inner_stmts, arg.line)
566+
567+
return retval
568+
569+
else:
570+
index_name = "__mypyc_sum_item__"
571+
572+
def body_insts() -> None:
573+
total = builder.binary_op(retval, builder.read(index_reg), "+", expr.line)
574+
builder.assign(retval, total, expr.line)
575+
576+
index_type = builder._analyze_iterable_item_type(arg)
577+
index = _create_iterable_lexpr(index_name, index_type)
578+
index_reg = builder.add_local_reg(index.node, builder.type_to_rtype(index_type))
579+
580+
for_loop_helper(
581+
builder,
582+
index,
583+
arg,
584+
body_insts,
585+
None,
586+
is_async=False,
587+
line=expr.line,
588+
)
589+
return retval
549590

550591

551592
@specialize_function("dataclasses.field")

mypyc/test-data/irbuild-basic.test

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2995,6 +2995,12 @@ from typing import Callable, Iterable
29952995
def call_sum(l: Iterable[int], comparison: Callable[[int], bool]) -> int:
29962996
return sum(comparison(x) for x in l)
29972997

2998+
def call_sum_helper(l: Iterable[int], comparison: Callable[[int], bool]):
2999+
return sum([comparison(x) for x in l])
3000+
3001+
def call_sum_helper_start(l: Iterable[int]):
3002+
return sum([str(i) for i in l], "")
3003+
29983004
[out]
29993005
def call_sum(l, comparison):
30003006
l, comparison :: object
@@ -3033,6 +3039,121 @@ L4:
30333039
r12 = CPy_NoErrOccurred()
30343040
L5:
30353041
return r0
3042+
def call_sum_helper(l, comparison):
3043+
l, comparison :: object
3044+
r0 :: int
3045+
r1 :: list
3046+
r2, r3 :: object
3047+
r4, x :: int
3048+
r5 :: object
3049+
r6 :: object[1]
3050+
r7 :: object_ptr
3051+
r8 :: object
3052+
r9 :: bool
3053+
r10 :: object
3054+
r11 :: i32
3055+
r12, r13 :: bit
3056+
r14, r15 :: native_int
3057+
r16 :: bit
3058+
r17 :: object
3059+
r18, __mypyc_sum_item__, r19 :: bool
3060+
r20, r21 :: int
3061+
r22 :: native_int
3062+
r23 :: object
3063+
L0:
3064+
r0 = 0
3065+
r1 = PyList_New(0)
3066+
r2 = PyObject_GetIter(l)
3067+
L1:
3068+
r3 = PyIter_Next(r2)
3069+
if is_error(r3) goto L4 else goto L2
3070+
L2:
3071+
r4 = unbox(int, r3)
3072+
x = r4
3073+
r5 = box(int, x)
3074+
r6 = [r5]
3075+
r7 = load_address r6
3076+
r8 = PyObject_Vectorcall(comparison, r7, 1, 0)
3077+
keep_alive r5
3078+
r9 = unbox(bool, r8)
3079+
r10 = box(bool, r9)
3080+
r11 = PyList_Append(r1, r10)
3081+
r12 = r11 >= 0 :: signed
3082+
L3:
3083+
goto L1
3084+
L4:
3085+
r13 = CPy_NoErrOccurred()
3086+
L5:
3087+
r14 = 0
3088+
L6:
3089+
r15 = var_object_size r1
3090+
r16 = r14 < r15 :: signed
3091+
if r16 goto L7 else goto L9 :: bool
3092+
L7:
3093+
r17 = list_get_item_unsafe r1, r14
3094+
r18 = unbox(bool, r17)
3095+
__mypyc_sum_item__ = r18
3096+
r19 = __mypyc_sum_item__ << 1
3097+
r20 = extend r19: builtins.bool to builtins.int
3098+
r21 = CPyTagged_Add(r0, r20)
3099+
r0 = r21
3100+
L8:
3101+
r22 = r14 + 1
3102+
r14 = r22
3103+
goto L6
3104+
L9:
3105+
r23 = box(int, r0)
3106+
return r23
3107+
def call_sum_helper_start(l):
3108+
l :: object
3109+
r0, r1 :: str
3110+
r2 :: list
3111+
r3, r4 :: object
3112+
r5, i :: int
3113+
r6 :: str
3114+
r7 :: i32
3115+
r8, r9 :: bit
3116+
r10, r11 :: native_int
3117+
r12 :: bit
3118+
r13 :: object
3119+
r14, __mypyc_sum_item__, r15 :: str
3120+
r16 :: native_int
3121+
L0:
3122+
r0 = ''
3123+
r1 = r0
3124+
r2 = PyList_New(0)
3125+
r3 = PyObject_GetIter(l)
3126+
L1:
3127+
r4 = PyIter_Next(r3)
3128+
if is_error(r4) goto L4 else goto L2
3129+
L2:
3130+
r5 = unbox(int, r4)
3131+
i = r5
3132+
r6 = CPyTagged_Str(i)
3133+
r7 = PyList_Append(r2, r6)
3134+
r8 = r7 >= 0 :: signed
3135+
L3:
3136+
goto L1
3137+
L4:
3138+
r9 = CPy_NoErrOccurred()
3139+
L5:
3140+
r10 = 0
3141+
L6:
3142+
r11 = var_object_size r2
3143+
r12 = r10 < r11 :: signed
3144+
if r12 goto L7 else goto L9 :: bool
3145+
L7:
3146+
r13 = list_get_item_unsafe r2, r10
3147+
r14 = cast(str, r13)
3148+
__mypyc_sum_item__ = r14
3149+
r15 = PyUnicode_Concat(r1, __mypyc_sum_item__)
3150+
r1 = r15
3151+
L8:
3152+
r16 = r10 + 1
3153+
r10 = r16
3154+
goto L6
3155+
L9:
3156+
return r1
30363157

30373158
[case testSetAttr1]
30383159
from typing import Any, Dict, List

0 commit comments

Comments
 (0)