Skip to content

Commit dbf1d6f

Browse files
Mark map/accumulate iterators exhausted when the user callback raises StopIteration
1 parent 3706ef6 commit dbf1d6f

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

Modules/itertoolsmodule.c

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2983,6 +2983,7 @@ typedef struct {
29832983
PyObject *it;
29842984
PyObject *binop;
29852985
PyObject *initial;
2986+
int finished;
29862987
itertools_state *state;
29872988
} accumulateobject;
29882989

@@ -3024,6 +3025,7 @@ itertools_accumulate_impl(PyTypeObject *type, PyObject *iterable,
30243025
lz->total = NULL;
30253026
lz->it = it;
30263027
lz->initial = Py_XNewRef(initial);
3028+
lz->finished = 0;
30273029
lz->state = find_state_by_type(type);
30283030
return (PyObject *)lz;
30293031
}
@@ -3060,6 +3062,10 @@ accumulate_next(PyObject *op)
30603062
accumulateobject *lz = accumulateobject_CAST(op);
30613063
PyObject *val, *newtotal;
30623064

3065+
if (lz->finished) {
3066+
return NULL;
3067+
}
3068+
30633069
if (lz->initial != Py_None) {
30643070
lz->total = lz->initial;
30653071
lz->initial = Py_NewRef(Py_None);
@@ -3079,8 +3085,12 @@ accumulate_next(PyObject *op)
30793085
else
30803086
newtotal = PyObject_CallFunctionObjArgs(lz->binop, lz->total, val, NULL);
30813087
Py_DECREF(val);
3082-
if (newtotal == NULL)
3088+
if (newtotal == NULL) {
3089+
if (lz->binop != NULL && PyErr_Occurred() && PyErr_ExceptionMatches(PyExc_StopIteration)) {
3090+
lz->finished = 1;
3091+
}
30833092
return NULL;
3093+
}
30843094

30853095
Py_INCREF(newtotal);
30863096
Py_SETREF(lz->total, newtotal);

Python/bltinmodule.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,6 +1354,7 @@ typedef struct {
13541354
PyObject *iters;
13551355
PyObject *func;
13561356
int strict;
1357+
int finished;
13571358
} mapobject;
13581359

13591360
#define _mapobject_CAST(op) ((mapobject *)(op))
@@ -1411,6 +1412,7 @@ map_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
14111412
func = PyTuple_GET_ITEM(args, 0);
14121413
lz->func = Py_NewRef(func);
14131414
lz->strict = strict;
1415+
lz->finished = 0;
14141416

14151417
return (PyObject *)lz;
14161418
}
@@ -1456,6 +1458,7 @@ map_vectorcall(PyObject *type, PyObject * const*args,
14561458
lz->iters = iters;
14571459
lz->func = Py_NewRef(args[0]);
14581460
lz->strict = 0;
1461+
lz->finished = 0;
14591462

14601463
return (PyObject *)lz;
14611464
}
@@ -1489,6 +1492,10 @@ map_next(PyObject *self)
14891492
PyObject *result = NULL;
14901493
PyThreadState *tstate = _PyThreadState_GET();
14911494

1495+
if (lz->finished) {
1496+
return NULL;
1497+
}
1498+
14921499
const Py_ssize_t niters = PyTuple_GET_SIZE(lz->iters);
14931500
if (niters <= (Py_ssize_t)Py_ARRAY_LENGTH(small_stack)) {
14941501
stack = small_stack;
@@ -1516,6 +1523,11 @@ map_next(PyObject *self)
15161523
}
15171524

15181525
result = _PyObject_VectorcallTstate(tstate, lz->func, stack, nargs, NULL);
1526+
if (result == NULL && PyErr_Occurred()) {
1527+
if (PyErr_ExceptionMatches(PyExc_StopIteration)) {
1528+
lz->finished = 1;
1529+
}
1530+
}
15191531

15201532
exit:
15211533
for (i=0; i < nargs; i++) {

0 commit comments

Comments
 (0)