 from itertools import product
 from functools import partial
 from distutils.version import LooseVersion
+from threading import Thread, Lock
+import time
 
 import numpy as np
 
@@ -689,6 +691,73 @@ def test_read_segments():
     assert_raises(Exception, read_segments, fobj, [(0, 100), (100, 200)], 199)
 
 
+def test_read_segments_lock():
+    # Test read_segments locking with multiple threads
+    fobj = BytesIO()
+    arr = np.random.randint(0, 256, 1000, dtype=np.uint8)
+    fobj.write(arr.tobytes())
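+    # arr is the ground truth; every segment read back from fobj is checked
+    # against the corresponding slice of arr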
+
+    # Encourage the interpreter to switch threads between a seek/read pair
+    def yielding_read(*args, **kwargs):
+        time.sleep(0.001)
+        return fobj._real_read(*args, **kwargs)
+
+    fobj._real_read = fobj.read
+    fobj.read = yielding_read
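+    # fobj is shared by every worker thread below; the sleep sits between
+    # the seek and the real read, so an unlocked seek/read pair is very
+    # likely to be interleaved by another thread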
+
+    # Generate some random array segments to read from the file
+    def random_segments(nsegs):
+        segs = []
+        nbytes = 0
+
+        for i in range(nsegs):
+            seglo = np.random.randint(0, 998)
+            seghi = np.random.randint(seglo + 1, 1000)
+            seglen = seghi - seglo
+            nbytes += seglen
+            segs.append([seglo, seglen])
+
+        return segs, nbytes
+
+    # Get the data that should be returned for the given segments
+    def get_expected(segs):
+        segs = [arr[off:off + length] for off, length in segs]
+        return np.concatenate(segs)
+
+    # Read from the file and check the result. We do this simultaneously in
+    # many threads. Each thread that passes adds 1 to numpassed[0]
+    numpassed = [0]
+    lock = Lock()
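+    # The same Lock is handed to every read_segments call so that each
+    # seek/read pair on the shared fobj can run atomically across threads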
+
+    def runtest():
+        seg, nbytes = random_segments(1)
+        expected = get_expected(seg)
+        _check_bytes(read_segments(fobj, seg, nbytes, lock), expected)
+
+        seg, nbytes = random_segments(10)
+        expected = get_expected(seg)
+        _check_bytes(read_segments(fobj, seg, nbytes, lock), expected)
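+        # Guard the shared counter as well; += on a list element is not atomic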
+        with lock:
+            numpassed[0] += 1
+
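+    # Run the check concurrently from 100 threads, all against the single
+    # shared file object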
+    threads = [Thread(target=runtest) for _ in range(100)]
+    for t in threads:
+        t.start()
+    for t in threads:
+        t.join()
+    assert numpassed[0] == len(threads)
+
+
 def _check_slicer(sliceobj, arr, fobj, offset, order,
                   heuristic=threshold_heuristic):
     new_slice = fileslice(fobj, sliceobj, arr.shape, arr.dtype, offset, order,