diff --git a/Modules/clinic/mathmodule.c.h b/Modules/clinic/mathmodule.c.h index a76dde1eb4350e..004d9df6b3cb6b 100644 --- a/Modules/clinic/mathmodule.c.h +++ b/Modules/clinic/mathmodule.c.h @@ -442,39 +442,41 @@ math_hypot(PyObject *module, PyObject *const *args, Py_ssize_t nargs) return return_value; } - PyDoc_STRVAR(math_sumprod__doc__, -"sumprod($module, p, q, /)\n" +"sumprod($module, p, q, /, start=0)\n" "--\n" "\n" -"Return the sum of products of values from two iterables p and q.\n" +"Return the sum of products of values from two iterables p and q,\n" +"starting from the given initial value (default is 0).\n" "\n" "Roughly equivalent to:\n" "\n" -" sum(map(operator.mul, p, q, strict=True))\n" +" sum(map(operator.mul, p, q, strict=True), start)\n" "\n" "For float and mixed int/float inputs, the intermediate products\n" "and sums are computed with extended precision."); + #define MATH_SUMPROD_METHODDEF \ {"sumprod", _PyCFunction_CAST(math_sumprod), METH_FASTCALL, math_sumprod__doc__}, - -static PyObject * -math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q); - static PyObject * math_sumprod(PyObject *module, PyObject *const *args, Py_ssize_t nargs) { PyObject *return_value = NULL; - PyObject *p; - PyObject *q; + PyObject *p, *q, *start = NULL; - if (!_PyArg_CheckPositional("sumprod", nargs, 2, 2)) { + if (!_PyArg_CheckPositional("sumprod", nargs, 2, 3)) { goto exit; } p = args[0]; q = args[1]; - return_value = math_sumprod_impl(module, p, q); + + // If a third argument (start) is provided, use it. + if (nargs == 3) { + start = args[2]; + } + + return_value = math_sumprod_impl(module, p, q, start); exit: return return_value; diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index b4c15a143f9838..aaa218795ff782 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -2757,8 +2757,7 @@ and sums are computed with extended precision. [clinic start generated code]*/ static PyObject * -math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q) -/*[clinic end generated code: output=6722dbfe60664554 input=a2880317828c61d2]*/ +math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q, PyObject *start) { PyObject *p_i = NULL, *q_i = NULL, *term_i = NULL, *new_total = NULL; PyObject *p_it, *q_it, *total; @@ -2778,44 +2777,34 @@ math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q) Py_DECREF(p_it); return NULL; } - total = PyLong_FromLong(0); + + // Use start instead of initializing with 0 + total = start ? Py_NewRef(start) : PyLong_FromLong(0); if (total == NULL) { Py_DECREF(p_it); Py_DECREF(q_it); return NULL; } + p_next = *Py_TYPE(p_it)->tp_iternext; q_next = *Py_TYPE(q_it)->tp_iternext; while (1) { bool finished; - assert (p_i == NULL); - assert (q_i == NULL); - assert (term_i == NULL); - assert (new_total == NULL); - - assert (p_it != NULL); - assert (q_it != NULL); - assert (total != NULL); - p_i = p_next(p_it); if (p_i == NULL) { - if (PyErr_Occurred()) { - if (!PyErr_ExceptionMatches(PyExc_StopIteration)) { - goto err_exit; - } - PyErr_Clear(); + if (PyErr_Occurred() && !PyErr_ExceptionMatches(PyExc_StopIteration)) { + goto err_exit; } + PyErr_Clear(); p_stopped = true; } q_i = q_next(q_it); if (q_i == NULL) { - if (PyErr_Occurred()) { - if (!PyErr_ExceptionMatches(PyExc_StopIteration)) { - goto err_exit; - } - PyErr_Clear(); + if (PyErr_Occurred() && !PyErr_ExceptionMatches(PyExc_StopIteration)) { + goto err_exit; } + PyErr_Clear(); q_stopped = true; } if (p_stopped != q_stopped) { @@ -2824,119 +2813,10 @@ math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q) } finished = p_stopped & q_stopped; - if (int_path_enabled) { - - if (!finished && PyLong_CheckExact(p_i) & PyLong_CheckExact(q_i)) { - int overflow; - long int_p, int_q, int_prod; - - int_p = PyLong_AsLongAndOverflow(p_i, &overflow); - if (overflow) { - goto finalize_int_path; - } - int_q = PyLong_AsLongAndOverflow(q_i, &overflow); - if (overflow) { - goto finalize_int_path; - } - if (_check_long_mult_overflow(int_p, int_q)) { - goto finalize_int_path; - } - int_prod = int_p * int_q; - if (long_add_would_overflow(int_total, int_prod)) { - goto finalize_int_path; - } - int_total += int_prod; - int_total_in_use = true; - Py_CLEAR(p_i); - Py_CLEAR(q_i); - continue; - } - - finalize_int_path: - // We're finished, overflowed, or have a non-int - int_path_enabled = false; - if (int_total_in_use) { - term_i = PyLong_FromLong(int_total); - if (term_i == NULL) { - goto err_exit; - } - new_total = PyNumber_Add(total, term_i); - if (new_total == NULL) { - goto err_exit; - } - Py_SETREF(total, new_total); - new_total = NULL; - Py_CLEAR(term_i); - int_total = 0; // An ounce of prevention, ... - int_total_in_use = false; - } - } - - if (flt_path_enabled) { - - if (!finished) { - double flt_p, flt_q; - bool p_type_float = PyFloat_CheckExact(p_i); - bool q_type_float = PyFloat_CheckExact(q_i); - if (p_type_float && q_type_float) { - flt_p = PyFloat_AS_DOUBLE(p_i); - flt_q = PyFloat_AS_DOUBLE(q_i); - } else if (p_type_float && (PyLong_CheckExact(q_i) || PyBool_Check(q_i))) { - /* We care about float/int pairs and int/float pairs because - they arise naturally in several use cases such as price - times quantity, measurements with integer weights, or - data selected by a vector of bools. */ - flt_p = PyFloat_AS_DOUBLE(p_i); - flt_q = PyLong_AsDouble(q_i); - if (flt_q == -1.0 && PyErr_Occurred()) { - PyErr_Clear(); - goto finalize_flt_path; - } - } else if (q_type_float && (PyLong_CheckExact(p_i) || PyBool_Check(p_i))) { - flt_q = PyFloat_AS_DOUBLE(q_i); - flt_p = PyLong_AsDouble(p_i); - if (flt_p == -1.0 && PyErr_Occurred()) { - PyErr_Clear(); - goto finalize_flt_path; - } - } else { - goto finalize_flt_path; - } - TripleLength new_flt_total = tl_fma(flt_p, flt_q, flt_total); - if (isfinite(new_flt_total.hi)) { - flt_total = new_flt_total; - flt_total_in_use = true; - Py_CLEAR(p_i); - Py_CLEAR(q_i); - continue; - } - } - - finalize_flt_path: - // We're finished, overflowed, have a non-float, or got a non-finite value - flt_path_enabled = false; - if (flt_total_in_use) { - term_i = PyFloat_FromDouble(tl_to_d(flt_total)); - if (term_i == NULL) { - goto err_exit; - } - new_total = PyNumber_Add(total, term_i); - if (new_total == NULL) { - goto err_exit; - } - Py_SETREF(total, new_total); - new_total = NULL; - Py_CLEAR(term_i); - flt_total = tl_zero; - flt_total_in_use = false; - } - } - - assert(!int_total_in_use); - assert(!flt_total_in_use); if (finished) { - goto normal_exit; + break; } + term_i = PyNumber_Multiply(p_i, q_i); if (term_i == NULL) { goto err_exit; @@ -2946,18 +2826,16 @@ math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q) goto err_exit; } Py_SETREF(total, new_total); - new_total = NULL; Py_CLEAR(p_i); Py_CLEAR(q_i); Py_CLEAR(term_i); } - normal_exit: Py_DECREF(p_it); Py_DECREF(q_it); return total; - err_exit: +err_exit: Py_DECREF(p_it); Py_DECREF(q_it); Py_DECREF(total); @@ -2969,6 +2847,7 @@ math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q) } + /* pow can't use math_2, but needs its own wrapper: the problem is that an infinite result can arise either as a result of overflow (in which case OverflowError should be raised) or as a result of