Skip to content
Merged
31 changes: 28 additions & 3 deletions Lib/test/test_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2455,9 +2455,6 @@ def check(funcs, a=None, *args):
with threading_helper.start_threads(threads):
pass

for thread in threads:
threading_helper.join_thread(thread)

# hard errors

check([clear] + [reduce] * 10)
Expand Down Expand Up @@ -2519,6 +2516,34 @@ def check(funcs, a=None, *args):
check([clear] + [upper] * 10, bytearray(b'a' * 0x400000))
check([clear] + [zfill] * 10, bytearray(b'1' * 0x200000))

@unittest.skipUnless(support.Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
@threading_helper.reap_threads
@threading_helper.requires_working_threading()
def test_free_threading_bytearrayiter(self):
# Non-deterministic but good chance to fail if bytearrayiter is not free-threading safe.
# We are fishing for a "Assertion failed: object has negative ref count".

def iter_next(b, it):
b.wait()
list(it)

def check(funcs, it):
barrier = threading.Barrier(len(funcs))
threads = []

for func in funcs:
thread = threading.Thread(target=func, args=(barrier, it))

threads.append(thread)

with threading_helper.start_threads(threads):
pass

for _ in range(10):
ba = bytearray(b'0' * 0x4000) # this is a load-bearing variable, do not remove

check([iter_next] * 10, iter(ba))


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make ``bytearray`` iterator safe under :term:`free threading`.
71 changes: 48 additions & 23 deletions Objects/bytearrayobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -2856,31 +2856,44 @@ static PyObject *
bytearrayiter_next(PyObject *self)
{
bytesiterobject *it = _bytesiterobject_CAST(self);
PyByteArrayObject *seq;
int val;

assert(it != NULL);
seq = it->it_seq;
if (seq == NULL)
Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
if (index < 0) {
return NULL;
}
PyByteArrayObject *seq = it->it_seq;
assert(PyByteArray_Check(seq));

if (it->it_index < PyByteArray_GET_SIZE(seq)) {
return _PyLong_FromUnsignedChar(
(unsigned char)PyByteArray_AS_STRING(seq)[it->it_index++]);
Py_BEGIN_CRITICAL_SECTION(seq);
if (index < Py_SIZE(seq)) {
val = (unsigned char)PyByteArray_AS_STRING(seq)[index];
}
else {
val = -1;
}
Py_END_CRITICAL_SECTION();

it->it_seq = NULL;
Py_DECREF(seq);
return NULL;
if (val == -1) {
FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, -1);
#ifndef Py_GIL_DISABLED
Py_CLEAR(it->it_seq);
#endif
return NULL;
}
FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, index + 1);
return _PyLong_FromUnsignedChar((unsigned char)val);
}

static PyObject *
bytearrayiter_length_hint(PyObject *self, PyObject *Py_UNUSED(ignored))
{
bytesiterobject *it = _bytesiterobject_CAST(self);
Py_ssize_t len = 0;
if (it->it_seq) {
len = PyByteArray_GET_SIZE(it->it_seq) - it->it_index;
Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
if (index >= 0) {
len = PyByteArray_GET_SIZE(it->it_seq) - index;
if (len < 0) {
len = 0;
}
Expand All @@ -2900,27 +2913,39 @@ bytearrayiter_reduce(PyObject *self, PyObject *Py_UNUSED(ignored))
* call must be before access of iterator pointers.
* see issue #101765 */
bytesiterobject *it = _bytesiterobject_CAST(self);
if (it->it_seq != NULL) {
return Py_BuildValue("N(O)n", iter, it->it_seq, it->it_index);
} else {
return Py_BuildValue("N(())", iter);
PyObject *ret = NULL;
Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
if (index >= 0) {
if (index <= PyByteArray_GET_SIZE(it->it_seq)) {
ret = Py_BuildValue("N(O)n", iter, it->it_seq, index);
}
}
if (ret == NULL) {
ret = Py_BuildValue("N(())", iter);
}
return ret;
}

static PyObject *
bytearrayiter_setstate(PyObject *self, PyObject *state)
{
Py_ssize_t index = PyLong_AsSsize_t(state);
if (index == -1 && PyErr_Occurred())
if (index == -1 && PyErr_Occurred()) {
return NULL;
}

bytesiterobject *it = _bytesiterobject_CAST(self);
if (it->it_seq != NULL) {
if (index < 0)
index = 0;
else if (index > PyByteArray_GET_SIZE(it->it_seq))
index = PyByteArray_GET_SIZE(it->it_seq); /* iterator exhausted */
it->it_index = index;
if (FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index) >= 0) {
if (index < -1) {
index = -1;
}
else {
Py_ssize_t size = PyByteArray_GET_SIZE(it->it_seq);
if (index > size) {
index = size; /* iterator at end */
}
}
FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, index);
}
Py_RETURN_NONE;
}
Expand Down Expand Up @@ -2982,7 +3007,7 @@ bytearray_iter(PyObject *seq)
it = PyObject_GC_New(bytesiterobject, &PyByteArrayIter_Type);
if (it == NULL)
return NULL;
it->it_index = 0;
it->it_index = 0; // -1 indicates exhausted
it->it_seq = (PyByteArrayObject *)Py_NewRef(seq);
_PyObject_GC_TRACK(it);
return (PyObject *)it;
Expand Down
Loading