Skip to content

Commit 926d734

Browse files
[3.14] gh-116738: make mmap module thread-safe (GH-139237) (#139825)
* [3.14] gh-116738: make `mmap` module thread-safe (GH-139237) (cherry picked from commit 7f155f9) Co-authored-by: Alper <[email protected]>
1 parent 7c03e90 commit 926d734

File tree

5 files changed

+1497
-197
lines changed

5 files changed

+1497
-197
lines changed
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
import unittest
2+
3+
from test.support import import_helper, threading_helper
4+
from test.support.threading_helper import run_concurrently
5+
6+
import os
7+
import string
8+
import tempfile
9+
import threading
10+
11+
from collections import Counter
12+
13+
mmap = import_helper.import_module("mmap")
14+
15+
NTHREADS = 10
16+
ANONYMOUS_MEM = -1
17+
18+
19+
@threading_helper.requires_working_threading()
20+
class MmapTests(unittest.TestCase):
21+
def test_read_and_read_byte(self):
22+
ascii_uppercase = string.ascii_uppercase.encode()
23+
# Choose a total mmap size that evenly divides across threads and the
24+
# read pattern (3 bytes per loop).
25+
mmap_size = 3 * NTHREADS * len(ascii_uppercase)
26+
num_bytes_to_read_per_thread = mmap_size // NTHREADS
27+
bytes_read_from_mmap = []
28+
29+
def read(mm_obj):
30+
nread = 0
31+
while nread < num_bytes_to_read_per_thread:
32+
b = mm_obj.read_byte()
33+
bytes_read_from_mmap.append(b)
34+
b = mm_obj.read(2)
35+
bytes_read_from_mmap.extend(b)
36+
nread += 3
37+
38+
with mmap.mmap(ANONYMOUS_MEM, mmap_size) as mm_obj:
39+
for i in range(mmap_size // len(ascii_uppercase)):
40+
mm_obj.write(ascii_uppercase)
41+
42+
mm_obj.seek(0)
43+
run_concurrently(
44+
worker_func=read,
45+
args=(mm_obj,),
46+
nthreads=NTHREADS,
47+
)
48+
49+
self.assertEqual(len(bytes_read_from_mmap), mmap_size)
50+
# Count each letter/byte to verify read correctness
51+
counter = Counter(bytes_read_from_mmap)
52+
self.assertEqual(len(counter), len(ascii_uppercase))
53+
# Each letter/byte should be read (3 * NTHREADS) times
54+
for letter in ascii_uppercase:
55+
self.assertEqual(counter[letter], 3 * NTHREADS)
56+
57+
def test_readline(self):
58+
num_lines = 1000
59+
lines_read_from_mmap = []
60+
expected_lines = []
61+
62+
def readline(mm_obj):
63+
for i in range(num_lines // NTHREADS):
64+
line = mm_obj.readline()
65+
lines_read_from_mmap.append(line)
66+
67+
# Allocate mmap enough for num_lines (max line 5 bytes including NL)
68+
with mmap.mmap(ANONYMOUS_MEM, num_lines * 5) as mm_obj:
69+
for i in range(num_lines):
70+
line = b"%d\n" % i
71+
mm_obj.write(line)
72+
expected_lines.append(line)
73+
74+
mm_obj.seek(0)
75+
run_concurrently(
76+
worker_func=readline,
77+
args=(mm_obj,),
78+
nthreads=NTHREADS,
79+
)
80+
81+
self.assertEqual(len(lines_read_from_mmap), num_lines)
82+
# Every line should be read once by threads; order is non-deterministic
83+
# Sort numerically by integer value
84+
lines_read_from_mmap.sort(key=lambda x: int(x))
85+
self.assertEqual(lines_read_from_mmap, expected_lines)
86+
87+
def test_write_and_write_byte(self):
88+
thread_letters = list(string.ascii_uppercase)
89+
self.assertLessEqual(NTHREADS, len(thread_letters))
90+
per_thread_write_loop = 100
91+
92+
def write(mm_obj):
93+
# Each thread picks a unique letter to write
94+
thread_letter = thread_letters.pop(0)
95+
thread_bytes = (thread_letter * 2).encode()
96+
for _ in range(per_thread_write_loop):
97+
mm_obj.write_byte(thread_bytes[0])
98+
mm_obj.write(thread_bytes)
99+
100+
with mmap.mmap(
101+
ANONYMOUS_MEM, per_thread_write_loop * 3 * NTHREADS
102+
) as mm_obj:
103+
run_concurrently(
104+
worker_func=write,
105+
args=(mm_obj,),
106+
nthreads=NTHREADS,
107+
)
108+
mm_obj.seek(0)
109+
data = mm_obj.read()
110+
self.assertEqual(len(data), NTHREADS * per_thread_write_loop * 3)
111+
counter = Counter(data)
112+
self.assertEqual(len(counter), NTHREADS)
113+
# Each thread letter should be written `per_thread_write_loop` * 3
114+
for letter in counter:
115+
self.assertEqual(counter[letter], per_thread_write_loop * 3)
116+
117+
def test_move(self):
118+
ascii_uppercase = string.ascii_uppercase.encode()
119+
num_letters = len(ascii_uppercase)
120+
121+
def move(mm_obj):
122+
for i in range(num_letters):
123+
# Move 1 byte from the first half to the second half
124+
mm_obj.move(0 + i, num_letters + i, 1)
125+
126+
with mmap.mmap(ANONYMOUS_MEM, 2 * num_letters) as mm_obj:
127+
mm_obj.write(ascii_uppercase)
128+
run_concurrently(
129+
worker_func=move,
130+
args=(mm_obj,),
131+
nthreads=NTHREADS,
132+
)
133+
134+
def test_seek_and_tell(self):
135+
seek_per_thread = 10
136+
137+
def seek(mm_obj):
138+
self.assertTrue(mm_obj.seekable())
139+
for _ in range(seek_per_thread):
140+
before_seek = mm_obj.tell()
141+
mm_obj.seek(1, os.SEEK_CUR)
142+
self.assertLess(before_seek, mm_obj.tell())
143+
144+
with mmap.mmap(ANONYMOUS_MEM, 1024) as mm_obj:
145+
run_concurrently(
146+
worker_func=seek,
147+
args=(mm_obj,),
148+
nthreads=NTHREADS,
149+
)
150+
# Each thread seeks from current position, the end position should
151+
# be the sum of all seeks from all threads.
152+
self.assertEqual(mm_obj.tell(), NTHREADS * seek_per_thread)
153+
154+
def test_slice_update_and_slice_read(self):
155+
thread_letters = list(string.ascii_uppercase)
156+
self.assertLessEqual(NTHREADS, len(thread_letters))
157+
158+
def slice_update_and_slice_read(mm_obj):
159+
# Each thread picks a unique letter to write
160+
thread_letter = thread_letters.pop(0)
161+
thread_bytes = (thread_letter * 1024).encode()
162+
for _ in range(100):
163+
mm_obj[:] = thread_bytes
164+
read_bytes = mm_obj[:]
165+
# Read bytes should be all the same letter, showing no
166+
# interleaving
167+
self.assertTrue(all_same(read_bytes))
168+
169+
with mmap.mmap(ANONYMOUS_MEM, 1024) as mm_obj:
170+
run_concurrently(
171+
worker_func=slice_update_and_slice_read,
172+
args=(mm_obj,),
173+
nthreads=NTHREADS,
174+
)
175+
176+
def test_item_update_and_item_read(self):
177+
thread_indexes = [i for i in range(NTHREADS)]
178+
179+
def item_update_and_item_read(mm_obj):
180+
# Each thread picks a unique index to write
181+
thread_index = thread_indexes.pop()
182+
for i in range(100):
183+
mm_obj[thread_index] = i
184+
self.assertEqual(mm_obj[thread_index], i)
185+
186+
# Read values set by other threads, all values
187+
# should be less than '100'
188+
for val in mm_obj:
189+
self.assertLess(int.from_bytes(val), 100)
190+
191+
with mmap.mmap(ANONYMOUS_MEM, NTHREADS + 1) as mm_obj:
192+
run_concurrently(
193+
worker_func=item_update_and_item_read,
194+
args=(mm_obj,),
195+
nthreads=NTHREADS,
196+
)
197+
198+
def test_close_and_closed(self):
199+
def close_mmap(mm_obj):
200+
mm_obj.close()
201+
self.assertTrue(mm_obj.closed)
202+
203+
with mmap.mmap(ANONYMOUS_MEM, 1) as mm_obj:
204+
run_concurrently(
205+
worker_func=close_mmap,
206+
args=(mm_obj,),
207+
nthreads=NTHREADS,
208+
)
209+
210+
def test_find_and_rfind(self):
211+
per_thread_loop = 10
212+
213+
def find_and_rfind(mm_obj):
214+
pattern = b'Thread-Ident:"%d"' % threading.get_ident()
215+
mm_obj.write(pattern)
216+
for _ in range(per_thread_loop):
217+
found_at = mm_obj.find(pattern, 0)
218+
self.assertNotEqual(found_at, -1)
219+
# Should not find it after the `found_at`
220+
self.assertEqual(mm_obj.find(pattern, found_at + 1), -1)
221+
found_at_rev = mm_obj.rfind(pattern, 0)
222+
self.assertEqual(found_at, found_at_rev)
223+
# Should not find it after the `found_at`
224+
self.assertEqual(mm_obj.rfind(pattern, found_at + 1), -1)
225+
226+
with mmap.mmap(ANONYMOUS_MEM, 1024) as mm_obj:
227+
run_concurrently(
228+
worker_func=find_and_rfind,
229+
args=(mm_obj,),
230+
nthreads=NTHREADS,
231+
)
232+
233+
def test_mmap_export_as_memoryview(self):
234+
"""
235+
Each thread creates a memoryview and updates the internal state of the
236+
mmap object.
237+
"""
238+
buffer_size = 42
239+
240+
def create_memoryview_from_mmap(mm_obj):
241+
memoryviews = []
242+
for _ in range(100):
243+
mv = memoryview(mm_obj)
244+
memoryviews.append(mv)
245+
self.assertEqual(len(mv), buffer_size)
246+
self.assertEqual(mv[:7], b"CPython")
247+
248+
# Cannot close the mmap while it is exported as buffers
249+
with self.assertRaisesRegex(
250+
BufferError, "cannot close exported pointers exist"
251+
):
252+
mm_obj.close()
253+
254+
with mmap.mmap(ANONYMOUS_MEM, 42) as mm_obj:
255+
mm_obj.write(b"CPython")
256+
run_concurrently(
257+
worker_func=create_memoryview_from_mmap,
258+
args=(mm_obj,),
259+
nthreads=NTHREADS,
260+
)
261+
# Implicit mm_obj.close() verifies all exports (memoryviews) are
262+
# properly freed.
263+
264+
265+
def all_same(lst):
266+
return all(item == lst[0] for item in lst)
267+
268+
269+
if __name__ == "__main__":
270+
unittest.main()

Lib/test/test_mmap.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,10 @@ def test_madvise(self):
887887
size = 2 * PAGESIZE
888888
m = mmap.mmap(-1, size)
889889

890+
class Number:
891+
def __index__(self):
892+
return 2
893+
890894
with self.assertRaisesRegex(ValueError, "madvise start out of bounds"):
891895
m.madvise(mmap.MADV_NORMAL, size)
892896
with self.assertRaisesRegex(ValueError, "madvise start out of bounds"):
@@ -895,10 +899,14 @@ def test_madvise(self):
895899
m.madvise(mmap.MADV_NORMAL, 0, -1)
896900
with self.assertRaisesRegex(OverflowError, "madvise length too large"):
897901
m.madvise(mmap.MADV_NORMAL, PAGESIZE, sys.maxsize)
902+
with self.assertRaisesRegex(
903+
TypeError, "'str' object cannot be interpreted as an integer"):
904+
m.madvise(mmap.MADV_NORMAL, PAGESIZE, "Not a Number")
898905
self.assertEqual(m.madvise(mmap.MADV_NORMAL), None)
899906
self.assertEqual(m.madvise(mmap.MADV_NORMAL, PAGESIZE), None)
900907
self.assertEqual(m.madvise(mmap.MADV_NORMAL, PAGESIZE, size), None)
901908
self.assertEqual(m.madvise(mmap.MADV_NORMAL, 0, 2), None)
909+
self.assertEqual(m.madvise(mmap.MADV_NORMAL, 0, Number()), None)
902910
self.assertEqual(m.madvise(mmap.MADV_NORMAL, 0, size), None)
903911

904912
def test_resize_up_anonymous_mapping(self):
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Make :mod:`mmap` thread-safe on the :term:`free threaded <free threading>`
2+
build.

0 commit comments

Comments
 (0)