Skip to content

Commit c19ad9f

Browse files
committed
feat: specialize sum from for helper
1 parent 41c7c9b commit c19ad9f

File tree

2 files changed

+170
-17
lines changed

2 files changed

+170
-17
lines changed

mypyc/irbuild/specialize.py

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
bytes_rprimitive,
5757
c_int_rprimitive,
5858
dict_rprimitive,
59+
float_rprimitive,
5960
int16_rprimitive,
6061
int32_rprimitive,
6162
int64_rprimitive,
@@ -69,6 +70,7 @@
6970
is_int64_rprimitive,
7071
is_int_rprimitive,
7172
is_list_rprimitive,
73+
is_object_rprimitive,
7274
is_uint8_rprimitive,
7375
list_rprimitive,
7476
object_rprimitive,
@@ -514,11 +516,11 @@ def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> V
514516
# - only one or two arguments given (if not, sum() has been given invalid arguments)
515517
# - first argument is a Generator (there is no benefit to optimizing the performance of e.g.
516518
# sum([1, 2, 3]), so non-Generator Iterables are not handled)
517-
if not (
518-
len(expr.args) in (1, 2)
519-
and expr.arg_kinds[0] == ARG_POS
520-
and isinstance(expr.args[0], GeneratorExpr)
521-
):
519+
if not (len(expr.args) in (1, 2) and expr.arg_kinds[0] == ARG_POS):
520+
return None
521+
522+
arg = expr.args[0]
523+
if not isinstance(arg, GeneratorExpr) and not _is_supported_forloop_iter(builder, arg):
522524
return None
523525

524526
# handle 'start' argument, if given
@@ -530,21 +532,51 @@ def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> V
530532
else:
531533
start_expr = IntExpr(0)
532534

533-
gen_expr = expr.args[0]
534-
target_type = builder.node_type(expr)
535-
retval = Register(target_type)
536-
builder.assign(retval, builder.coerce(builder.accept(start_expr), target_type, -1), -1)
535+
item_type = builder._analyze_iterable_item_type(arg)
536+
item_rtype = builder.type_to_rtype(item_type)
537+
start_rtype = builder.node_type(start_expr)
537538

538-
def gen_inner_stmts() -> None:
539-
call_expr = builder.accept(gen_expr.left_expr)
540-
builder.assign(retval, builder.binary_op(retval, call_expr, "+", -1), -1)
539+
if item_rtype is start_rtype:
540+
acc_rtype = item_rtype
541+
elif is_float_rprimitive(item_rtype) and is_int_rprimitive(start_rtype):
542+
acc_rtype = float_rprimitive
543+
elif is_bool_rprimitive(item_rtype) and is_int_rprimitive(start_rtype):
544+
acc_rtype = int_rprimitive
545+
elif is_object_rprimitive(item_rtype) and is_int_rprimitive(start_rtype):
546+
acc_rtype = object_rprimitive
541547

542-
loop_params = list(
543-
zip(gen_expr.indices, gen_expr.sequences, gen_expr.condlists, gen_expr.is_async)
544-
)
545-
comprehension_helper(builder, loop_params, gen_inner_stmts, gen_expr.line)
548+
else:
549+
# escape hatch: unsupported item/start type combination, so no specialization is emitted; maybe figure out a better way to handle this whole block
550+
# seeking ideas in review
551+
return None
546552

547-
return retval
553+
retval = Register(acc_rtype)
554+
builder.assign(retval, builder.coerce(builder.accept(start_expr), acc_rtype, -1), -1)
555+
556+
if isinstance(arg, GeneratorExpr):
557+
558+
def gen_inner_stmts() -> None:
559+
call_expr = builder.accept(arg.left_expr)
560+
builder.assign(retval, builder.binary_op(retval, call_expr, "+", -1), -1)
561+
562+
loop_params = list(zip(arg.indices, arg.sequences, arg.condlists, arg.is_async))
563+
comprehension_helper(builder, loop_params, gen_inner_stmts, arg.line)
564+
565+
return retval
566+
567+
else:
568+
index_name = "__mypyc_sum_item__"
569+
570+
def body_insts() -> None:
571+
total = builder.binary_op(retval, builder.read(index_reg), "+", expr.line)
572+
builder.assign(retval, total, expr.line)
573+
574+
index_type = builder._analyze_iterable_item_type(arg)
575+
index = _create_iterable_lexpr(index_name, index_type)
576+
index_reg = builder.add_local_reg(index.node, builder.type_to_rtype(index_type)) # type: ignore [arg-type]
577+
578+
for_loop_helper(builder, index, arg, body_insts, None, is_async=False, line=expr.line)
579+
return retval
548580

549581

550582
@specialize_function("dataclasses.field")

mypyc/test-data/irbuild-basic.test

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2995,6 +2995,12 @@ from typing import Callable, Iterable
29952995
def call_sum(l: Iterable[int], comparison: Callable[[int], bool]) -> int:
29962996
return sum(comparison(x) for x in l)
29972997

2998+
def call_sum_helper(l: Iterable[int], comparison: Callable[[int], bool]):
2999+
return sum([comparison(x) for x in l])
3000+
3001+
def call_sum_helper_start(l: Iterable[int]):
3002+
return sum([str(i) for i in l], "")
3003+
29983004
[out]
29993005
def call_sum(l, comparison):
30003006
l, comparison :: object
@@ -3033,6 +3039,121 @@ L4:
30333039
r12 = CPy_NoErrOccurred()
30343040
L5:
30353041
return r0
3042+
def call_sum_helper(l, comparison):
3043+
l, comparison :: object
3044+
r0 :: int
3045+
r1 :: list
3046+
r2, r3 :: object
3047+
r4, x :: int
3048+
r5 :: object
3049+
r6 :: object[1]
3050+
r7 :: object_ptr
3051+
r8 :: object
3052+
r9 :: bool
3053+
r10 :: object
3054+
r11 :: i32
3055+
r12, r13 :: bit
3056+
r14, r15 :: native_int
3057+
r16 :: bit
3058+
r17 :: object
3059+
r18, __mypyc_sum_item__, r19 :: bool
3060+
r20, r21 :: int
3061+
r22 :: native_int
3062+
r23 :: object
3063+
L0:
3064+
r0 = 0
3065+
r1 = PyList_New(0)
3066+
r2 = PyObject_GetIter(l)
3067+
L1:
3068+
r3 = PyIter_Next(r2)
3069+
if is_error(r3) goto L4 else goto L2
3070+
L2:
3071+
r4 = unbox(int, r3)
3072+
x = r4
3073+
r5 = box(int, x)
3074+
r6 = [r5]
3075+
r7 = load_address r6
3076+
r8 = PyObject_Vectorcall(comparison, r7, 1, 0)
3077+
keep_alive r5
3078+
r9 = unbox(bool, r8)
3079+
r10 = box(bool, r9)
3080+
r11 = PyList_Append(r1, r10)
3081+
r12 = r11 >= 0 :: signed
3082+
L3:
3083+
goto L1
3084+
L4:
3085+
r13 = CPy_NoErrOccurred()
3086+
L5:
3087+
r14 = 0
3088+
L6:
3089+
r15 = var_object_size r1
3090+
r16 = r14 < r15 :: signed
3091+
if r16 goto L7 else goto L9 :: bool
3092+
L7:
3093+
r17 = list_get_item_unsafe r1, r14
3094+
r18 = unbox(bool, r17)
3095+
__mypyc_sum_item__ = r18
3096+
r19 = __mypyc_sum_item__ << 1
3097+
r20 = extend r19: builtins.bool to builtins.int
3098+
r21 = CPyTagged_Add(r0, r20)
3099+
r0 = r21
3100+
L8:
3101+
r22 = r14 + 1
3102+
r14 = r22
3103+
goto L6
3104+
L9:
3105+
r23 = box(int, r0)
3106+
return r23
3107+
def call_sum_helper_start(l):
3108+
l :: object
3109+
r0, r1 :: str
3110+
r2 :: list
3111+
r3, r4 :: object
3112+
r5, i :: int
3113+
r6 :: str
3114+
r7 :: i32
3115+
r8, r9 :: bit
3116+
r10, r11 :: native_int
3117+
r12 :: bit
3118+
r13 :: object
3119+
r14, __mypyc_sum_item__, r15 :: str
3120+
r16 :: native_int
3121+
L0:
3122+
r0 = ''
3123+
r1 = r0
3124+
r2 = PyList_New(0)
3125+
r3 = PyObject_GetIter(l)
3126+
L1:
3127+
r4 = PyIter_Next(r3)
3128+
if is_error(r4) goto L4 else goto L2
3129+
L2:
3130+
r5 = unbox(int, r4)
3131+
i = r5
3132+
r6 = CPyTagged_Str(i)
3133+
r7 = PyList_Append(r2, r6)
3134+
r8 = r7 >= 0 :: signed
3135+
L3:
3136+
goto L1
3137+
L4:
3138+
r9 = CPy_NoErrOccurred()
3139+
L5:
3140+
r10 = 0
3141+
L6:
3142+
r11 = var_object_size r2
3143+
r12 = r10 < r11 :: signed
3144+
if r12 goto L7 else goto L9 :: bool
3145+
L7:
3146+
r13 = list_get_item_unsafe r2, r10
3147+
r14 = cast(str, r13)
3148+
__mypyc_sum_item__ = r14
3149+
r15 = PyUnicode_Concat(r1, __mypyc_sum_item__)
3150+
r1 = r15
3151+
L8:
3152+
r16 = r10 + 1
3153+
r10 = r16
3154+
goto L6
3155+
L9:
3156+
return r1
30363157

30373158
[case testSetAttr1]
30383159
from typing import Any, Dict, List

0 commit comments

Comments
 (0)