Skip to content
Merged
335 changes: 335 additions & 0 deletions Lib/test/test_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
import copy
import functools
import pickle
import sysconfig
import tempfile
import textwrap
import threading
import unittest

import test.support
from test.support import import_helper
from test.support import threading_helper
from test.support import warnings_helper
import test.string_tests
import test.list_tests
Expand Down Expand Up @@ -2185,5 +2188,337 @@ class BytesSubclassTest(SubclassTest, unittest.TestCase):
type2test = BytesSubclass


class FreeThreadingTest(unittest.TestCase):
@unittest.skipUnless(sysconfig.get_config_var('Py_GIL_DISABLED'),
'this test can only possibly fail with GIL disabled')
@threading_helper.reap_threads
@threading_helper.requires_working_threading()
def test_free_threading_bytearray(self):
# Test pretty much everything that can break under free-threading.
# Non-deterministic, but at least one of these things will fail if
# bytearray module is not free-thread safe.

def clear(b, a, *args): # MODIFIES!
b.wait()
try: a.clear()
except BufferError: pass

def clear2(b, a, c): # MODIFIES c!
b.wait()
try: c.clear()
except BufferError: pass

def pop1(b, a): # MODIFIES!
b.wait()
try: a.pop()
except IndexError: pass

def append1(b, a): # MODIFIES!
b.wait()
a.append(0)

def insert1(b, a): # MODIFIES!
b.wait()
a.insert(0, 0)

def extend(b, a): # MODIFIES!
c = bytearray(b'0' * 0x400000)
b.wait()
a.extend(c)

def remove(b, a): # MODIFIES!
c = ord('0')
b.wait()
try: a.remove(c)
except ValueError: pass

def reverse(b, a): # modifies inplace
b.wait()
a.reverse()

def reduce(b, a):
b.wait()
a.__reduce__()

def reduceex2(b, a):
b.wait()
a.__reduce_ex__(2)

def reduceex3(b, a):
b.wait()
c = a.__reduce_ex__(3)
assert not c[1] or 0xdd not in c[1][0]

def count0(b, a):
b.wait()
a.count(0)

def decode(b, a):
b.wait()
a.decode()

def find(b, a):
c = bytearray(b'0' * 0x40000)
b.wait()
a.find(c)

def hex(b, a):
b.wait()
a.hex('_')

def join(b, a):
b.wait()
a.join([b'1', b'2', b'3'])

def replace(b, a):
b.wait()
a.replace(b'0', b'')

def maketrans(b, a, c):
b.wait()
try: a.maketrans(a, c)
except ValueError: pass

def translate(b, a, c):
b.wait()
a.translate(c)

def copy(b, a):
b.wait()
c = a.copy()
if c: assert c[0] == 48 # '0'

def endswith(b, a):
b.wait()
assert not a.endswith(b'\xdd')

def index(b, a):
b.wait()
try: a.index(b'\xdd')
except ValueError: return
assert False

def lstrip(b, a):
b.wait()
assert not a.lstrip(b'0')

def partition(b, a):
b.wait()
assert not a.partition(b'\xdd')[2]

def removeprefix(b, a):
b.wait()
assert not a.removeprefix(b'0')

def removesuffix(b, a):
b.wait()
assert not a.removesuffix(b'0')

def rfind(b, a):
b.wait()
assert a.rfind(b'\xdd') == -1

def rindex(b, a):
b.wait()
try: a.rindex(b'\xdd')
except ValueError: return
assert False

def rpartition(b, a):
b.wait()
assert not a.rpartition(b'\xdd')[0]

def rsplit(b, a):
b.wait()
assert len(a.rsplit(b'\xdd')) == 1

def rstrip(b, a):
b.wait()
assert not a.rstrip(b'0')

def split(b, a):
b.wait()
assert len(a.split(b'\xdd')) == 1

def splitlines(b, a):
b.wait()
l = len(a.splitlines())
assert l > 1 or l == 0

def startswith(b, a):
b.wait()
assert not a.startswith(b'\xdd')

def strip(b, a):
b.wait()
assert not a.strip(b'0')

def repeat(b, a):
b.wait()
a * 2

def contains(b, a):
b.wait()
assert 0xdd not in a

def iconcat(b, a): # MODIFIES!
c = bytearray(b'0' * 0x400000)
b.wait()
a += c

def irepeat(b, a): # MODIFIES!
b.wait()
a *= 2

def subscript(b, a):
b.wait()
try: assert a[0] != 0xdd
except IndexError: pass

def ass_subscript(b, a): # MODIFIES!
c = bytearray(b'0' * 0x400000)
b.wait()
a[:] = c

def mod(b, a):
c = tuple(range(4096))
b.wait()
try: a % c
except TypeError: pass

def repr_(b, a):
b.wait()
repr(a)

def capitalize(b, a):
b.wait()
c = a.capitalize()
assert not c or c[0] not in (0xdd, 0xcd)

def center(b, a):
b.wait()
c = a.center(0x60000)
assert not c or c[0x20000] not in (0xdd, 0xcd)

def expandtabs(b, a):
b.wait()
c = a.expandtabs()
assert not c or c[0] not in (0xdd, 0xcd)

def ljust(b, a):
b.wait()
c = a.ljust(0x600000)
assert not c or c[0] not in (0xdd, 0xcd)

def lower(b, a):
b.wait()
c = a.lower()
assert not c or c[0] not in (0xdd, 0xcd)

def rjust(b, a):
b.wait()
c = a.rjust(0x600000)
assert not c or c[-1] not in (0xdd, 0xcd)

def swapcase(b, a):
b.wait()
c = a.swapcase()
assert not c or c[-1] not in (0xdd, 0xcd)

def title(b, a):
b.wait()
c = a.title()
assert not c or c[-1] not in (0xdd, 0xcd)

def upper(b, a):
b.wait()
c = a.upper()
assert not c or c[-1] not in (0xdd, 0xcd)

def zfill(b, a):
b.wait()
c = a.zfill(0x400000)
assert not c or c[-1] not in (0xdd, 0xcd)

def check(funcs, a=None, *args):
if a is None:
a = bytearray(b'0' * 0x400000)

barrier = threading.Barrier(len(funcs))
threads = []

for func in funcs:
thread = threading.Thread(target=func, args=(barrier, a, *args))

threads.append(thread)

with threading_helper.start_threads(threads):
pass

for thread in threads:
threading_helper.join_thread(thread)

# hard errors

check([clear] + [reduce] * 10)
check([clear] + [reduceex2] * 10)
check([clear] + [append1] * 10)
check([clear] * 10)
check([clear] + [count0] * 10)
check([clear] + [decode] * 10)
check([clear] + [extend] * 10)
check([clear] + [find] * 10)
check([clear] + [hex] * 10)
check([clear] + [insert1] * 10)
check([clear] + [join] * 10)
check([clear] + [pop1] * 10)
check([clear] + [remove] * 10)
check([clear] + [replace] * 10)
check([clear] + [reverse] * 10)
check([clear, clear2] + [maketrans] * 10, bytearray(range(128)), bytearray(range(128)))
check([clear] + [translate] * 10, None, bytearray.maketrans(bytearray(range(128)), bytearray(range(128))))

check([clear] + [repeat] * 10)
check([clear] + [iconcat] * 10)
check([clear] + [irepeat] * 10)
check([clear] + [ass_subscript] * 10)
check([clear] + [repr_] * 10)

# value errors

check([clear] + [reduceex3] * 10, bytearray(b'a' * 0x40000))
check([clear] + [copy] * 10)
check([clear] + [endswith] * 10)
check([clear] + [index] * 10)
check([clear] + [lstrip] * 10)
check([clear] + [partition] * 10)
check([clear] + [removeprefix] * 10, bytearray(b'0'))
check([clear] + [removesuffix] * 10, bytearray(b'0'))
check([clear] + [rfind] * 10)
check([clear] + [rindex] * 10)
check([clear] + [rpartition] * 10)
check([clear] + [rsplit] * 10, bytearray(b'0' * 0x4000))
check([clear] + [rstrip] * 10)
check([clear] + [split] * 10, bytearray(b'0' * 0x4000))
check([clear] + [splitlines] * 10, bytearray(b'\n' * 0x400))
check([clear] + [startswith] * 10)
check([clear] + [strip] * 10)

check([clear] + [contains] * 10)
check([clear] + [subscript] * 10)
check([clear] + [mod] * 10, bytearray(b'%d' * 4096))

check([clear] + [capitalize] * 10, bytearray(b'a' * 0x40000))
check([clear] + [center] * 10, bytearray(b'a' * 0x40000))
check([clear] + [expandtabs] * 10, bytearray(b'0\t' * 4096))
check([clear] + [ljust] * 10, bytearray(b'0' * 0x400000))
check([clear] + [lower] * 10, bytearray(b'A' * 0x400000))
check([clear] + [rjust] * 10, bytearray(b'0' * 0x400000))
check([clear] + [swapcase] * 10, bytearray(b'aA' * 0x200000))
check([clear] + [title] * 10, bytearray(b'aA' * 0x200000))
check([clear] + [upper] * 10, bytearray(b'a' * 0x400000))
check([clear] + [zfill] * 10, bytearray(b'1' * 0x200000))


if __name__ == "__main__":
unittest.main()
Loading
Loading