55
55
bytes_rprimitive ,
56
56
c_int_rprimitive ,
57
57
dict_rprimitive ,
58
+ float_rprimitive ,
58
59
int16_rprimitive ,
59
60
int32_rprimitive ,
60
61
int64_rprimitive ,
68
69
is_int64_rprimitive ,
69
70
is_int_rprimitive ,
70
71
is_list_rprimitive ,
72
+ is_object_rprimitive ,
71
73
is_uint8_rprimitive ,
72
74
list_rprimitive ,
73
75
object_rprimitive ,
@@ -511,11 +513,11 @@ def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> V
511
513
# - only one or two arguments given (if not, sum() has been given invalid arguments)
512
514
# - first argument is a Generator (there is no benefit to optimizing the performance of eg.
513
515
# sum([1, 2, 3]), so non-Generator Iterables are not handled)
514
- if not (
515
- len ( expr . args ) in ( 1 , 2 )
516
- and expr . arg_kinds [ 0 ] == ARG_POS
517
- and isinstance ( expr .args [0 ], GeneratorExpr )
518
- ):
516
+ if not (len ( expr . args ) in ( 1 , 2 ) and expr . arg_kinds [ 0 ] == ARG_POS ):
517
+ return None
518
+
519
+ arg = expr .args [0 ]
520
+ if not isinstance ( arg , GeneratorExpr ) and not _is_supported_forloop_iter ( builder , arg ):
519
521
return None
520
522
521
523
# handle 'start' argument, if given
@@ -527,21 +529,60 @@ def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> V
527
529
else :
528
530
start_expr = IntExpr (0 )
529
531
530
- gen_expr = expr .args [0 ]
531
- target_type = builder .node_type (expr )
532
- retval = Register (target_type )
533
- builder .assign (retval , builder .coerce (builder .accept (start_expr ), target_type , - 1 ), - 1 )
532
+ item_type = builder ._analyze_iterable_item_type (arg )
533
+ item_rtype = builder .type_to_rtype (item_type )
534
+ start_rtype = builder .node_type (start_expr )
535
+
536
+ if item_rtype is start_rtype :
537
+ acc_rtype = item_rtype
538
+ elif is_float_rprimitive (item_rtype ) and is_int_rprimitive (start_rtype ):
539
+ acc_rtype = float_rprimitive
540
+ elif is_bool_rprimitive (item_rtype ) and is_int_rprimitive (start_rtype ):
541
+ acc_rtype = int_rprimitive
542
+ elif is_object_rprimitive (item_rtype ) and is_int_rprimitive (start_rtype ):
543
+ acc_rtype = object_rprimitive
544
+
545
+ else :
546
+ # escape hatch, maybe figure out a better way to handle this whole block
547
+ # seeking ideas in review
548
+ return None
534
549
535
- def gen_inner_stmts () -> None :
536
- call_expr = builder .accept (gen_expr .left_expr )
537
- builder .assign (retval , builder .binary_op (retval , call_expr , "+" , - 1 ), - 1 )
550
+ retval = Register (acc_rtype )
551
+ builder .assign (retval , builder .coerce (builder .accept (start_expr ), acc_rtype , - 1 ), - 1 )
538
552
539
- loop_params = list (
540
- zip ( gen_expr . indices , gen_expr . sequences , gen_expr . condlists , gen_expr . is_async )
541
- )
542
- comprehension_helper ( builder , loop_params , gen_inner_stmts , gen_expr . line )
553
+ if isinstance ( arg , GeneratorExpr ):
554
+ def gen_inner_stmts () -> None :
555
+ call_expr = builder . accept ( arg . left_expr )
556
+ builder . assign ( retval , builder . binary_op ( retval , call_expr , "+" , - 1 ), - 1 )
543
557
544
- return retval
558
+ loop_params = list (
559
+ zip (arg .indices , arg .sequences , arg .condlists , arg .is_async )
560
+ )
561
+ comprehension_helper (builder , loop_params , gen_inner_stmts , arg .line )
562
+
563
+ return retval
564
+
565
+ else :
566
+ index_name = "__mypyc_sum_item__"
567
+
568
+ def body_insts () -> None :
569
+ total = builder .binary_op (retval , builder .read (index_reg ), "+" , expr .line )
570
+ builder .assign (retval , total , expr .line )
571
+
572
+ index_type = builder ._analyze_iterable_item_type (arg )
573
+ index = _create_iterable_lexpr (index_name , index_type )
574
+ index_reg = builder .add_local_reg (index .node , builder .type_to_rtype (index_type ))
575
+
576
+ for_loop_helper (
577
+ builder ,
578
+ index ,
579
+ arg ,
580
+ body_insts ,
581
+ None ,
582
+ is_async = False ,
583
+ line = expr .line ,
584
+ )
585
+ return retval
545
586
546
587
547
588
@specialize_function ("dataclasses.field" )
0 commit comments