Skip to content

Commit 052ec60

Browse files
committed
TEST: Unit test to check thread-safety of fileslice.read_segments
1 parent 7cbbab2 commit 052ec60

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

nibabel/tests/test_fileslice.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from itertools import product
99
from functools import partial
1010
from distutils.version import LooseVersion
11+
from threading import Thread, Lock
12+
import time
1113

1214
import numpy as np
1315

@@ -689,6 +691,60 @@ def test_read_segments():
689691
assert_raises(Exception, read_segments, fobj, [(0, 100), (100, 200)], 199)
690692

691693

694+
def test_read_segments_lock():
695+
# Test read_segment locking with multiple threads
696+
fobj = BytesIO()
697+
arr = np.random.randint(0, 256, 1000, dtype=np.uint8)
698+
fobj.write(arr.tostring())
699+
700+
# Encourage the interpreter to switch threads between a seek/read pair
701+
def yielding_read(*args, **kwargs):
702+
time.sleep(0.001)
703+
return fobj._real_read(*args, **kwargs)
704+
705+
fobj._real_read = fobj.read
706+
fobj.read = yielding_read
707+
708+
# Generate some random array segments to read from the file
709+
def random_segments(nsegs):
710+
segs = []
711+
nbytes = 0
712+
713+
for i in range(nsegs):
714+
seglo = np.random.randint(0, 998)
715+
seghi = np.random.randint(seglo + 1, 1000)
716+
seglen = seghi - seglo
717+
nbytes += seglen
718+
segs.append([seglo, seglen])
719+
720+
return segs, nbytes
721+
722+
# Get the data that should be returned for the given segments
723+
def get_expected(segs):
724+
segs = [arr[off:off + length] for off, length in segs]
725+
return np.concatenate(segs)
726+
727+
# Read from the file, check the result. We do this task simultaneously in
728+
# many threads. Each thread that passes adds 1 to numpassed[0]
729+
numpassed = [0]
730+
lock = Lock()
731+
732+
def runtest():
733+
seg, nbytes = random_segments(1)
734+
expected = get_expected(seg)
735+
_check_bytes(read_segments(fobj, seg, nbytes, lock), expected)
736+
737+
seg, nbytes = random_segments(10)
738+
expected = get_expected(seg)
739+
_check_bytes(read_segments(fobj, seg, nbytes, lock), expected)
740+
numpassed[0] += 1
741+
742+
threads = [Thread(target=runtest) for i in range(100)]
743+
[t.start() for t in threads]
744+
[t.join() for t in threads]
745+
assert numpassed[0] == len(threads)
746+
747+
692748
def _check_slicer(sliceobj, arr, fobj, offset, order,
693749
heuristic=threshold_heuristic):
694750
new_slice = fileslice(fobj, sliceobj, arr.shape, arr.dtype, offset, order,

0 commit comments

Comments
 (0)