5656 bytes_rprimitive ,
5757 c_int_rprimitive ,
5858 dict_rprimitive ,
59+ float_rprimitive ,
5960 int16_rprimitive ,
6061 int32_rprimitive ,
6162 int64_rprimitive ,
6970 is_int64_rprimitive ,
7071 is_int_rprimitive ,
7172 is_list_rprimitive ,
73+ is_object_rprimitive ,
7274 is_uint8_rprimitive ,
7375 list_rprimitive ,
7476 object_rprimitive ,
@@ -514,11 +516,11 @@ def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> V
514516 # - only one or two arguments given (if not, sum() has been given invalid arguments)
515517 # - first argument is a Generator (there is no benefit to optimizing the performance of eg.
516518 # sum([1, 2, 3]), so non-Generator Iterables are not handled)
517- if not (
518- len ( expr . args ) in ( 1 , 2 )
519- and expr . arg_kinds [ 0 ] == ARG_POS
520- and isinstance ( expr .args [0 ], GeneratorExpr )
521- ):
519+ if not (len ( expr . args ) in ( 1 , 2 ) and expr . arg_kinds [ 0 ] == ARG_POS ):
520+ return None
521+
522+ arg = expr .args [0 ]
523+ if not isinstance ( arg , GeneratorExpr ) and not _is_supported_forloop_iter ( builder , arg ):
522524 return None
523525
524526 # handle 'start' argument, if given
@@ -530,21 +532,51 @@ def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> V
530532 else :
531533 start_expr = IntExpr (0 )
532534
533- gen_expr = expr .args [0 ]
534- target_type = builder .node_type (expr )
535- retval = Register (target_type )
536- builder .assign (retval , builder .coerce (builder .accept (start_expr ), target_type , - 1 ), - 1 )
535+ item_type = builder ._analyze_iterable_item_type (arg )
536+ item_rtype = builder .type_to_rtype (item_type )
537+ start_rtype = builder .node_type (start_expr )
537538
538- def gen_inner_stmts () -> None :
539- call_expr = builder .accept (gen_expr .left_expr )
540- builder .assign (retval , builder .binary_op (retval , call_expr , "+" , - 1 ), - 1 )
539+ if item_rtype is start_rtype :
540+ acc_rtype = item_rtype
541+ elif is_float_rprimitive (item_rtype ) and is_int_rprimitive (start_rtype ):
542+ acc_rtype = float_rprimitive
543+ elif is_bool_rprimitive (item_rtype ) and is_int_rprimitive (start_rtype ):
544+ acc_rtype = int_rprimitive
545+ elif is_object_rprimitive (item_rtype ) and is_int_rprimitive (start_rtype ):
546+ acc_rtype = object_rprimitive
541547
542- loop_params = list (
543- zip ( gen_expr . indices , gen_expr . sequences , gen_expr . condlists , gen_expr . is_async )
544- )
545- comprehension_helper ( builder , loop_params , gen_inner_stmts , gen_expr . line )
548+ else :
549+ # escape hatch, maybe figure out a better way to handle this whole block
550+ # seeking ideas in review
551+ return None
546552
547- return retval
553+ retval = Register (acc_rtype )
554+ builder .assign (retval , builder .coerce (builder .accept (start_expr ), acc_rtype , - 1 ), - 1 )
555+
556+ if isinstance (arg , GeneratorExpr ):
557+
558+ def gen_inner_stmts () -> None :
559+ call_expr = builder .accept (arg .left_expr )
560+ builder .assign (retval , builder .binary_op (retval , call_expr , "+" , - 1 ), - 1 )
561+
562+ loop_params = list (zip (arg .indices , arg .sequences , arg .condlists , arg .is_async ))
563+ comprehension_helper (builder , loop_params , gen_inner_stmts , arg .line )
564+
565+ return retval
566+
567+ else :
568+ index_name = "__mypyc_sum_item__"
569+
570+ def body_insts () -> None :
571+ total = builder .binary_op (retval , builder .read (index_reg ), "+" , expr .line )
572+ builder .assign (retval , total , expr .line )
573+
574+ index_type = builder ._analyze_iterable_item_type (arg )
575+ index = _create_iterable_lexpr (index_name , index_type )
576+ index_reg = builder .add_local_reg (index .node , builder .type_to_rtype (index_type )) # type: ignore [arg-type]
577+
578+ for_loop_helper (builder , index , arg , body_insts , None , is_async = False , line = expr .line )
579+ return retval
548580
549581
550582@specialize_function ("dataclasses.field" )
0 commit comments