Skip to content

Commit 932ec76

Browse files
committed
bytearray iterator made free-thread safe
1 parent 832cc05 commit 932ec76

File tree

1 file changed

+51
-23
lines changed

1 file changed

+51
-23
lines changed

Objects/bytearrayobject.c

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2825,31 +2825,45 @@ static PyObject *
28252825
bytearrayiter_next(PyObject *self)
28262826
{
28272827
bytesiterobject *it = _bytesiterobject_CAST(self);
2828-
PyByteArrayObject *seq;
2828+
int val;
28292829

28302830
assert(it != NULL);
2831-
seq = it->it_seq;
2832-
if (seq == NULL)
2831+
Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
2832+
if (index < 0) {
28332833
return NULL;
2834+
}
2835+
PyByteArrayObject *seq = it->it_seq;
28342836
assert(PyByteArray_Check(seq));
28352837

2836-
if (it->it_index < PyByteArray_GET_SIZE(seq)) {
2837-
return _PyLong_FromUnsignedChar(
2838-
(unsigned char)PyByteArray_AS_STRING(seq)[it->it_index++]);
2838+
Py_BEGIN_CRITICAL_SECTION(seq);
2839+
if (index < PyByteArray_GET_SIZE(seq)) {
2840+
val = (unsigned char)PyByteArray_AS_STRING(seq)[index];
28392841
}
2842+
else {
2843+
val = -1;
2844+
}
2845+
Py_END_CRITICAL_SECTION();
28402846

2841-
it->it_seq = NULL;
2842-
Py_DECREF(seq);
2843-
return NULL;
2847+
if (val == -1) {
2848+
FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, -1);
2849+
#ifndef Py_GIL_DISABLED
2850+
it->seq = NULL;
2851+
Py_DECREF(seq);
2852+
#endif
2853+
return NULL;
2854+
}
2855+
FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, index + 1);
2856+
return _PyLong_FromUnsignedChar((unsigned char)val);
28442857
}
28452858

28462859
static PyObject *
28472860
bytearrayiter_length_hint(PyObject *self, PyObject *Py_UNUSED(ignored))
28482861
{
28492862
bytesiterobject *it = _bytesiterobject_CAST(self);
28502863
Py_ssize_t len = 0;
2851-
if (it->it_seq) {
2852-
len = PyByteArray_GET_SIZE(it->it_seq) - it->it_index;
2864+
Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
2865+
if (index >= 0) {
2866+
len = PyByteArray_GET_SIZE(it->it_seq) - index;
28532867
if (len < 0) {
28542868
len = 0;
28552869
}
@@ -2869,27 +2883,41 @@ bytearrayiter_reduce(PyObject *self, PyObject *Py_UNUSED(ignored))
28692883
* call must be before access of iterator pointers.
28702884
* see issue #101765 */
28712885
bytesiterobject *it = _bytesiterobject_CAST(self);
2872-
if (it->it_seq != NULL) {
2873-
return Py_BuildValue("N(O)n", iter, it->it_seq, it->it_index);
2874-
} else {
2875-
return Py_BuildValue("N(())", iter);
2886+
PyObject *ret = NULL;
2887+
Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
2888+
if (index >= 0) {
2889+
Py_BEGIN_CRITICAL_SECTION(it->it_seq);
2890+
if (index <= PyByteArray_GET_SIZE(it->it_seq)) {
2891+
ret = Py_BuildValue("N(O)n", iter, it->it_seq, index);
2892+
}
2893+
Py_END_CRITICAL_SECTION();
28762894
}
2895+
if (ret == NULL) {
2896+
ret = Py_BuildValue("N(())", iter);
2897+
}
2898+
return ret;
28772899
}
28782900

28792901
static PyObject *
28802902
bytearrayiter_setstate(PyObject *self, PyObject *state)
28812903
{
28822904
Py_ssize_t index = PyLong_AsSsize_t(state);
2883-
if (index == -1 && PyErr_Occurred())
2905+
if (index == -1 && PyErr_Occurred()) {
28842906
return NULL;
2907+
}
28852908

28862909
bytesiterobject *it = _bytesiterobject_CAST(self);
2887-
if (it->it_seq != NULL) {
2888-
if (index < 0)
2889-
index = 0;
2890-
else if (index > PyByteArray_GET_SIZE(it->it_seq))
2891-
index = PyByteArray_GET_SIZE(it->it_seq); /* iterator exhausted */
2892-
it->it_index = index;
2910+
if (FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index) >= 0) {
2911+
if (index < -1) {
2912+
index = -1;
2913+
}
2914+
else {
2915+
Py_ssize_t size = PyByteArray_GET_SIZE(it->it_seq);
2916+
if (index > size) {
2917+
index = size; /* iterator at end */
2918+
}
2919+
}
2920+
FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, index);
28932921
}
28942922
Py_RETURN_NONE;
28952923
}
@@ -2951,7 +2979,7 @@ bytearray_iter(PyObject *seq)
29512979
it = PyObject_GC_New(bytesiterobject, &PyByteArrayIter_Type);
29522980
if (it == NULL)
29532981
return NULL;
2954-
it->it_index = 0;
2982+
it->it_index = 0; // -1 indicates exhausted
29552983
it->it_seq = (PyByteArrayObject *)Py_NewRef(seq);
29562984
_PyObject_GC_TRACK(it);
29572985
return (PyObject *)it;

0 commit comments

Comments
 (0)