Skip to content

Commit 62bf9c2

Browse files
committed
make cycle_next thread safe
1 parent 9e5cb51 commit 62bf9c2

File tree

2 files changed

+42
-9
lines changed

2 files changed

+42
-9
lines changed

Lib/test/test_free_threading/test_itertools_batched.py renamed to Lib/test/test_free_threading/test_itertools.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import unittest
22
from threading import Thread, Barrier
3-
from itertools import batched
3+
from itertools import batched, cycle
44
from test.support import threading_helper
55

66

77
threading_helper.requires_working_threading(module=True)
88

9-
class EnumerateThreading(unittest.TestCase):
9+
class ItertoolsThreading(unittest.TestCase):
1010

1111
@threading_helper.reap_threads
12-
def test_threading(self):
12+
def test_batched(self):
1313
number_of_threads = 10
1414
number_of_iterations = 20
1515
barrier = Barrier(number_of_threads)
@@ -34,5 +34,31 @@ def work(it):
3434

3535
barrier.reset()
3636

37+
@threading_helper.reap_threads
38+
def test_cycle(self):
39+
number_of_threads = 6
40+
number_of_iterations = 10
41+
number_of_cycles = 400
42+
43+
barrier = Barrier(number_of_threads)
44+
def work(it):
45+
barrier.wait()
46+
for _ in range(number_of_cycles):
47+
_ = next(it)
48+
49+
data = (1, 2, 3, 4)
50+
for it in range(number_of_iterations):
51+
cycle_iterator = cycle(data)
52+
worker_threads = []
53+
for ii in range(number_of_threads):
54+
worker_threads.append(
55+
Thread(target=work, args=[cycle_iterator]))
56+
57+
with threading_helper.start_threads(worker_threads):
58+
pass
59+
60+
barrier.reset()
61+
62+
3763
if __name__ == "__main__":
3864
unittest.main()

Modules/itertoolsmodule.c

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,7 +1166,7 @@ itertools_cycle_impl(PyTypeObject *type, PyObject *iterable)
11661166
}
11671167
lz->it = it;
11681168
lz->saved = saved;
1169-
lz->index = 0;
1169+
lz->index = -1;
11701170

11711171
return (PyObject *)lz;
11721172
}
@@ -1199,7 +1199,9 @@ cycle_next(PyObject *op)
11991199
cycleobject *lz = cycleobject_CAST(op);
12001200
PyObject *item;
12011201

1202-
if (lz->it != NULL) {
1202+
Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(lz->index);
1203+
1204+
if (index < 0) {
12031205
item = PyIter_Next(lz->it);
12041206
if (item != NULL) {
12051207
if (PyList_Append(lz->saved, item)) {
@@ -1211,14 +1213,19 @@ cycle_next(PyObject *op)
12111213
/* Note: StopIteration is already cleared by PyIter_Next() */
12121214
if (PyErr_Occurred())
12131215
return NULL;
1216+
index = 0;
1217+
FT_ATOMIC_STORE_SSIZE_RELAXED(lz->index, 0);
1218+
#ifndef Py_GIL_DISABLED
12141219
Py_CLEAR(lz->it);
1220+
#endif
12151221
}
12161222
if (PyList_GET_SIZE(lz->saved) == 0)
12171223
return NULL;
1218-
item = PyList_GET_ITEM(lz->saved, lz->index);
1219-
lz->index++;
1220-
if (lz->index >= PyList_GET_SIZE(lz->saved))
1221-
lz->index = 0;
1224+
item = PyList_GET_ITEM(lz->saved, index);
1225+
index++;
1226+
if (index >= PyList_GET_SIZE(lz->saved)) {
1227+
FT_ATOMIC_STORE_SSIZE_RELAXED(lz->index, 0);
1228+
}
12221229
return Py_NewRef(item);
12231230
}
12241231

0 commit comments

Comments
 (0)