Skip to content

Commit cac384d

Browse files
committed
gh-129107: make bytearrayiter free-threading safe
1 parent 07f5e33 commit cac384d

File tree

2 files changed

+89
-23
lines changed

2 files changed

+89
-23
lines changed

Lib/test/test_bytes.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111
import copy
1212
import functools
1313
import pickle
14+
import sysconfig
1415
import tempfile
1516
import textwrap
17+
import threading
1618
import unittest
1719

1820
import test.support
1921
from test.support import import_helper
22+
from test.support import threading_helper
2023
from test.support import warnings_helper
2124
import test.string_tests
2225
import test.list_tests
@@ -2185,5 +2188,39 @@ class BytesSubclassTest(SubclassTest, unittest.TestCase):
21852188
type2test = BytesSubclass
21862189

21872190

2191+
class FreeThreadingTest(unittest.TestCase):
2192+
@unittest.skipUnless(sysconfig.get_config_var('Py_GIL_DISABLED'),
2193+
'this test can only possibly fail with GIL disabled')
2194+
@threading_helper.reap_threads
2195+
@threading_helper.requires_working_threading()
2196+
def test_free_threading_bytearrayiter(self):
2197+
# Non-deterministic, but has a good chance of failing if bytearrayiter is not free-threading safe.
2198+
# We are fishing for an "Assertion failed: object has negative ref count".
2199+
2200+
def iter_next(b, it):
2201+
b.wait()
2202+
list(it)
2203+
2204+
def check(funcs, it):
2205+
barrier = threading.Barrier(len(funcs))
2206+
threads = []
2207+
2208+
for func in funcs:
2209+
thread = threading.Thread(target=func, args=(barrier, it))
2210+
2211+
threads.append(thread)
2212+
2213+
with threading_helper.start_threads(threads):
2214+
pass
2215+
2216+
for thread in threads:
2217+
threading_helper.join_thread(thread)
2218+
2219+
for _ in range(10):
2220+
ba = bytearray(b'0' * 0x4000) # this is a load-bearing variable, do not remove
2221+
2222+
check([iter_next] * 10, iter(ba))
2223+
2224+
21882225
if __name__ == "__main__":
21892226
unittest.main()

Objects/bytearrayobject.c

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "pycore_bytes_methods.h"
66
#include "pycore_bytesobject.h"
77
#include "pycore_ceval.h" // _PyEval_GetBuiltin()
8+
#include "pycore_critical_section.h"
89
#include "pycore_object.h" // _PyObject_GC_UNTRACK()
910
#include "pycore_strhex.h" // _Py_strhex_with_sep()
1011
#include "pycore_long.h" // _PyLong_FromUnsignedChar()
@@ -2519,31 +2520,45 @@ static PyObject *
25192520
bytearrayiter_next(PyObject *self)
25202521
{
25212522
bytesiterobject *it = _bytesiterobject_CAST(self);
2522-
PyByteArrayObject *seq;
2523+
int val;
25232524

25242525
assert(it != NULL);
2525-
seq = it->it_seq;
2526-
if (seq == NULL)
2526+
Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
2527+
if (index < 0) {
25272528
return NULL;
2529+
}
2530+
PyByteArrayObject *seq = it->it_seq;
25282531
assert(PyByteArray_Check(seq));
25292532

2530-
if (it->it_index < PyByteArray_GET_SIZE(seq)) {
2531-
return _PyLong_FromUnsignedChar(
2532-
(unsigned char)PyByteArray_AS_STRING(seq)[it->it_index++]);
2533+
Py_BEGIN_CRITICAL_SECTION(seq);
2534+
if (index < PyByteArray_GET_SIZE(seq)) {
2535+
val = (unsigned char)PyByteArray_AS_STRING(seq)[index];
25332536
}
2537+
else {
2538+
val = -1;
2539+
}
2540+
Py_END_CRITICAL_SECTION();
25342541

2535-
it->it_seq = NULL;
2536-
Py_DECREF(seq);
2537-
return NULL;
2542+
if (val == -1) {
2543+
FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, -1);
2544+
#ifndef Py_GIL_DISABLED
2545+
it->it_seq = NULL;
2546+
Py_DECREF(seq);
2547+
#endif
2548+
return NULL;
2549+
}
2550+
FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, index + 1);
2551+
return _PyLong_FromUnsignedChar((unsigned char)val);
25382552
}
25392553

25402554
static PyObject *
25412555
bytearrayiter_length_hint(PyObject *self, PyObject *Py_UNUSED(ignored))
25422556
{
25432557
bytesiterobject *it = _bytesiterobject_CAST(self);
25442558
Py_ssize_t len = 0;
2545-
if (it->it_seq) {
2546-
len = PyByteArray_GET_SIZE(it->it_seq) - it->it_index;
2559+
Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
2560+
if (index >= 0) {
2561+
len = PyByteArray_GET_SIZE(it->it_seq) - index;
25472562
if (len < 0) {
25482563
len = 0;
25492564
}
@@ -2563,27 +2578,41 @@ bytearrayiter_reduce(PyObject *self, PyObject *Py_UNUSED(ignored))
25632578
* call must be before access of iterator pointers.
25642579
* see issue #101765 */
25652580
bytesiterobject *it = _bytesiterobject_CAST(self);
2566-
if (it->it_seq != NULL) {
2567-
return Py_BuildValue("N(O)n", iter, it->it_seq, it->it_index);
2568-
} else {
2569-
return Py_BuildValue("N(())", iter);
2581+
PyObject *ret = NULL;
2582+
Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
2583+
if (index >= 0) {
2584+
Py_BEGIN_CRITICAL_SECTION(it->it_seq);
2585+
if (index <= PyByteArray_GET_SIZE(it->it_seq)) {
2586+
ret = Py_BuildValue("N(O)n", iter, it->it_seq, index);
2587+
}
2588+
Py_END_CRITICAL_SECTION();
2589+
}
2590+
if (ret == NULL) {
2591+
ret = Py_BuildValue("N(())", iter);
25702592
}
2593+
return ret;
25712594
}
25722595

25732596
static PyObject *
25742597
bytearrayiter_setstate(PyObject *self, PyObject *state)
25752598
{
25762599
Py_ssize_t index = PyLong_AsSsize_t(state);
2577-
if (index == -1 && PyErr_Occurred())
2600+
if (index == -1 && PyErr_Occurred()) {
25782601
return NULL;
2602+
}
25792603

25802604
bytesiterobject *it = _bytesiterobject_CAST(self);
2581-
if (it->it_seq != NULL) {
2582-
if (index < 0)
2583-
index = 0;
2584-
else if (index > PyByteArray_GET_SIZE(it->it_seq))
2585-
index = PyByteArray_GET_SIZE(it->it_seq); /* iterator exhausted */
2586-
it->it_index = index;
2605+
if (FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index) >= 0) {
2606+
if (index < -1) {
2607+
index = -1;
2608+
}
2609+
else {
2610+
Py_ssize_t size = PyByteArray_GET_SIZE(it->it_seq);
2611+
if (index > size) {
2612+
index = size; /* iterator at end */
2613+
}
2614+
}
2615+
FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, index);
25872616
}
25882617
Py_RETURN_NONE;
25892618
}
@@ -2645,7 +2674,7 @@ bytearray_iter(PyObject *seq)
26452674
it = PyObject_GC_New(bytesiterobject, &PyByteArrayIter_Type);
26462675
if (it == NULL)
26472676
return NULL;
2648-
it->it_index = 0;
2677+
it->it_index = 0; // -1 indicates exhausted
26492678
it->it_seq = (PyByteArrayObject *)Py_NewRef(seq);
26502679
_PyObject_GC_TRACK(it);
26512680
return (PyObject *)it;

0 commit comments

Comments
 (0)