Skip to content

Commit 7f155f9

Browse files
authored
gh-116738: make mmap module thread-safe (#139237)
1 parent e7e3d1d commit 7f155f9

File tree

5 files changed

+1562
-201
lines changed

5 files changed

+1562
-201
lines changed
Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
import unittest
2+
3+
from test.support import import_helper, threading_helper
4+
from test.support.threading_helper import run_concurrently
5+
6+
import os
7+
import string
8+
import tempfile
9+
import threading
10+
11+
from collections import Counter
12+
13+
mmap = import_helper.import_module("mmap")
14+
15+
NTHREADS = 10
16+
ANONYMOUS_MEM = -1
17+
18+
19+
@threading_helper.requires_working_threading()
20+
class MmapTests(unittest.TestCase):
21+
def test_read_and_read_byte(self):
22+
ascii_uppercase = string.ascii_uppercase.encode()
23+
# Choose a total mmap size that evenly divides across threads and the
24+
# read pattern (3 bytes per loop).
25+
mmap_size = 3 * NTHREADS * len(ascii_uppercase)
26+
num_bytes_to_read_per_thread = mmap_size // NTHREADS
27+
bytes_read_from_mmap = []
28+
29+
def read(mm_obj):
30+
nread = 0
31+
while nread < num_bytes_to_read_per_thread:
32+
b = mm_obj.read_byte()
33+
bytes_read_from_mmap.append(b)
34+
b = mm_obj.read(2)
35+
bytes_read_from_mmap.extend(b)
36+
nread += 3
37+
38+
with mmap.mmap(ANONYMOUS_MEM, mmap_size) as mm_obj:
39+
for i in range(mmap_size // len(ascii_uppercase)):
40+
mm_obj.write(ascii_uppercase)
41+
42+
mm_obj.seek(0)
43+
run_concurrently(
44+
worker_func=read,
45+
args=(mm_obj,),
46+
nthreads=NTHREADS,
47+
)
48+
49+
self.assertEqual(len(bytes_read_from_mmap), mmap_size)
50+
# Count each letter/byte to verify read correctness
51+
counter = Counter(bytes_read_from_mmap)
52+
self.assertEqual(len(counter), len(ascii_uppercase))
53+
# Each letter/byte should be read (3 * NTHREADS) times
54+
for letter in ascii_uppercase:
55+
self.assertEqual(counter[letter], 3 * NTHREADS)
56+
57+
def test_readline(self):
58+
num_lines = 1000
59+
lines_read_from_mmap = []
60+
expected_lines = []
61+
62+
def readline(mm_obj):
63+
for i in range(num_lines // NTHREADS):
64+
line = mm_obj.readline()
65+
lines_read_from_mmap.append(line)
66+
67+
# Allocate mmap enough for num_lines (max line 5 bytes including NL)
68+
with mmap.mmap(ANONYMOUS_MEM, num_lines * 5) as mm_obj:
69+
for i in range(num_lines):
70+
line = b"%d\n" % i
71+
mm_obj.write(line)
72+
expected_lines.append(line)
73+
74+
mm_obj.seek(0)
75+
run_concurrently(
76+
worker_func=readline,
77+
args=(mm_obj,),
78+
nthreads=NTHREADS,
79+
)
80+
81+
self.assertEqual(len(lines_read_from_mmap), num_lines)
82+
# Every line should be read once by threads; order is non-deterministic
83+
# Sort numerically by integer value
84+
lines_read_from_mmap.sort(key=lambda x: int(x))
85+
self.assertEqual(lines_read_from_mmap, expected_lines)
86+
87+
def test_write_and_write_byte(self):
88+
thread_letters = list(string.ascii_uppercase)
89+
self.assertLessEqual(NTHREADS, len(thread_letters))
90+
per_thread_write_loop = 100
91+
92+
def write(mm_obj):
93+
# Each thread picks a unique letter to write
94+
thread_letter = thread_letters.pop(0)
95+
thread_bytes = (thread_letter * 2).encode()
96+
for _ in range(per_thread_write_loop):
97+
mm_obj.write_byte(thread_bytes[0])
98+
mm_obj.write(thread_bytes)
99+
100+
with mmap.mmap(
101+
ANONYMOUS_MEM, per_thread_write_loop * 3 * NTHREADS
102+
) as mm_obj:
103+
run_concurrently(
104+
worker_func=write,
105+
args=(mm_obj,),
106+
nthreads=NTHREADS,
107+
)
108+
mm_obj.seek(0)
109+
data = mm_obj.read()
110+
self.assertEqual(len(data), NTHREADS * per_thread_write_loop * 3)
111+
counter = Counter(data)
112+
self.assertEqual(len(counter), NTHREADS)
113+
# Each thread letter should be written `per_thread_write_loop` * 3
114+
for letter in counter:
115+
self.assertEqual(counter[letter], per_thread_write_loop * 3)
116+
117+
def test_move(self):
118+
ascii_uppercase = string.ascii_uppercase.encode()
119+
num_letters = len(ascii_uppercase)
120+
121+
def move(mm_obj):
122+
for i in range(num_letters):
123+
# Move 1 byte from the first half to the second half
124+
mm_obj.move(0 + i, num_letters + i, 1)
125+
126+
with mmap.mmap(ANONYMOUS_MEM, 2 * num_letters) as mm_obj:
127+
mm_obj.write(ascii_uppercase)
128+
run_concurrently(
129+
worker_func=move,
130+
args=(mm_obj,),
131+
nthreads=NTHREADS,
132+
)
133+
134+
def test_seek_and_tell(self):
135+
seek_per_thread = 10
136+
137+
def seek(mm_obj):
138+
self.assertTrue(mm_obj.seekable())
139+
for _ in range(seek_per_thread):
140+
before_seek = mm_obj.tell()
141+
mm_obj.seek(1, os.SEEK_CUR)
142+
self.assertLess(before_seek, mm_obj.tell())
143+
144+
with mmap.mmap(ANONYMOUS_MEM, 1024) as mm_obj:
145+
run_concurrently(
146+
worker_func=seek,
147+
args=(mm_obj,),
148+
nthreads=NTHREADS,
149+
)
150+
# Each thread seeks from current position, the end position should
151+
# be the sum of all seeks from all threads.
152+
self.assertEqual(mm_obj.tell(), NTHREADS * seek_per_thread)
153+
154+
def test_slice_update_and_slice_read(self):
155+
thread_letters = list(string.ascii_uppercase)
156+
self.assertLessEqual(NTHREADS, len(thread_letters))
157+
158+
def slice_update_and_slice_read(mm_obj):
159+
# Each thread picks a unique letter to write
160+
thread_letter = thread_letters.pop(0)
161+
thread_bytes = (thread_letter * 1024).encode()
162+
for _ in range(100):
163+
mm_obj[:] = thread_bytes
164+
read_bytes = mm_obj[:]
165+
# Read bytes should be all the same letter, showing no
166+
# interleaving
167+
self.assertTrue(all_same(read_bytes))
168+
169+
with mmap.mmap(ANONYMOUS_MEM, 1024) as mm_obj:
170+
run_concurrently(
171+
worker_func=slice_update_and_slice_read,
172+
args=(mm_obj,),
173+
nthreads=NTHREADS,
174+
)
175+
176+
def test_item_update_and_item_read(self):
177+
thread_indexes = [i for i in range(NTHREADS)]
178+
179+
def item_update_and_item_read(mm_obj):
180+
# Each thread picks a unique index to write
181+
thread_index = thread_indexes.pop()
182+
for i in range(100):
183+
mm_obj[thread_index] = i
184+
self.assertEqual(mm_obj[thread_index], i)
185+
186+
# Read values set by other threads, all values
187+
# should be less than '100'
188+
for val in mm_obj:
189+
self.assertLess(int.from_bytes(val), 100)
190+
191+
with mmap.mmap(ANONYMOUS_MEM, NTHREADS + 1) as mm_obj:
192+
run_concurrently(
193+
worker_func=item_update_and_item_read,
194+
args=(mm_obj,),
195+
nthreads=NTHREADS,
196+
)
197+
198+
@unittest.skipUnless(os.name == "posix", "requires Posix")
199+
@unittest.skipUnless(hasattr(mmap.mmap, "resize"), "requires mmap.resize")
200+
def test_resize_and_size(self):
201+
thread_indexes = [i for i in range(NTHREADS)]
202+
203+
def resize_and_item_update(mm_obj):
204+
# Each thread picks a unique index to write
205+
thread_index = thread_indexes.pop()
206+
mm_obj.resize(2048)
207+
self.assertEqual(mm_obj.size(), 2048)
208+
for i in range(100):
209+
mm_obj[thread_index] = i
210+
self.assertEqual(mm_obj[thread_index], i)
211+
212+
with mmap.mmap(ANONYMOUS_MEM, 1024, flags=mmap.MAP_PRIVATE) as mm_obj:
213+
run_concurrently(
214+
worker_func=resize_and_item_update,
215+
args=(mm_obj,),
216+
nthreads=NTHREADS,
217+
)
218+
219+
def test_close_and_closed(self):
220+
def close_mmap(mm_obj):
221+
mm_obj.close()
222+
self.assertTrue(mm_obj.closed)
223+
224+
with mmap.mmap(ANONYMOUS_MEM, 1) as mm_obj:
225+
run_concurrently(
226+
worker_func=close_mmap,
227+
args=(mm_obj,),
228+
nthreads=NTHREADS,
229+
)
230+
231+
def test_find_and_rfind(self):
232+
per_thread_loop = 10
233+
234+
def find_and_rfind(mm_obj):
235+
pattern = b'Thread-Ident:"%d"' % threading.get_ident()
236+
mm_obj.write(pattern)
237+
for _ in range(per_thread_loop):
238+
found_at = mm_obj.find(pattern, 0)
239+
self.assertNotEqual(found_at, -1)
240+
# Should not find it after the `found_at`
241+
self.assertEqual(mm_obj.find(pattern, found_at + 1), -1)
242+
found_at_rev = mm_obj.rfind(pattern, 0)
243+
self.assertEqual(found_at, found_at_rev)
244+
# Should not find it after the `found_at`
245+
self.assertEqual(mm_obj.rfind(pattern, found_at + 1), -1)
246+
247+
with mmap.mmap(ANONYMOUS_MEM, 1024) as mm_obj:
248+
run_concurrently(
249+
worker_func=find_and_rfind,
250+
args=(mm_obj,),
251+
nthreads=NTHREADS,
252+
)
253+
254+
@unittest.skipUnless(os.name == "posix", "requires Posix")
255+
@unittest.skipUnless(hasattr(mmap.mmap, "resize"), "requires mmap.resize")
256+
def test_flush(self):
257+
mmap_filename = "test_mmap_file"
258+
resize_to = 1024
259+
260+
def resize_and_flush(mm_obj):
261+
mm_obj.resize(resize_to)
262+
mm_obj.flush()
263+
264+
with tempfile.TemporaryDirectory() as tmpdirname:
265+
file_path = f"{tmpdirname}/{mmap_filename}"
266+
with open(file_path, "wb+") as file:
267+
file.write(b"CPython")
268+
file.flush()
269+
with mmap.mmap(file.fileno(), 1) as mm_obj:
270+
run_concurrently(
271+
worker_func=resize_and_flush,
272+
args=(mm_obj,),
273+
nthreads=NTHREADS,
274+
)
275+
276+
self.assertEqual(os.path.getsize(file_path), resize_to)
277+
278+
def test_mmap_export_as_memoryview(self):
279+
"""
280+
Each thread creates a memoryview and updates the internal state of the
281+
mmap object.
282+
"""
283+
buffer_size = 42
284+
285+
def create_memoryview_from_mmap(mm_obj):
286+
memoryviews = []
287+
for _ in range(100):
288+
mv = memoryview(mm_obj)
289+
memoryviews.append(mv)
290+
self.assertEqual(len(mv), buffer_size)
291+
self.assertEqual(mv[:7], b"CPython")
292+
293+
# Cannot close the mmap while it is exported as buffers
294+
with self.assertRaisesRegex(
295+
BufferError, "cannot close exported pointers exist"
296+
):
297+
mm_obj.close()
298+
299+
with mmap.mmap(ANONYMOUS_MEM, 42) as mm_obj:
300+
mm_obj.write(b"CPython")
301+
run_concurrently(
302+
worker_func=create_memoryview_from_mmap,
303+
args=(mm_obj,),
304+
nthreads=NTHREADS,
305+
)
306+
# Implicit mm_obj.close() verifies all exports (memoryviews) are
307+
# properly freed.
308+
309+
310+
def all_same(lst):
311+
return all(item == lst[0] for item in lst)
312+
313+
314+
if __name__ == "__main__":
315+
unittest.main()

Lib/test/test_mmap.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,10 @@ def test_madvise(self):
871871
size = 2 * PAGESIZE
872872
m = mmap.mmap(-1, size)
873873

874+
class Number:
875+
def __index__(self):
876+
return 2
877+
874878
with self.assertRaisesRegex(ValueError, "madvise start out of bounds"):
875879
m.madvise(mmap.MADV_NORMAL, size)
876880
with self.assertRaisesRegex(ValueError, "madvise start out of bounds"):
@@ -879,10 +883,14 @@ def test_madvise(self):
879883
m.madvise(mmap.MADV_NORMAL, 0, -1)
880884
with self.assertRaisesRegex(OverflowError, "madvise length too large"):
881885
m.madvise(mmap.MADV_NORMAL, PAGESIZE, sys.maxsize)
886+
with self.assertRaisesRegex(
887+
TypeError, "'str' object cannot be interpreted as an integer"):
888+
m.madvise(mmap.MADV_NORMAL, PAGESIZE, "Not a Number")
882889
self.assertEqual(m.madvise(mmap.MADV_NORMAL), None)
883890
self.assertEqual(m.madvise(mmap.MADV_NORMAL, PAGESIZE), None)
884891
self.assertEqual(m.madvise(mmap.MADV_NORMAL, PAGESIZE, size), None)
885892
self.assertEqual(m.madvise(mmap.MADV_NORMAL, 0, 2), None)
893+
self.assertEqual(m.madvise(mmap.MADV_NORMAL, 0, Number()), None)
886894
self.assertEqual(m.madvise(mmap.MADV_NORMAL, 0, size), None)
887895

888896
@unittest.skipUnless(hasattr(mmap.mmap, 'resize'), 'requires mmap.resize')
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Make :mod:`mmap` thread-safe on the :term:`free threaded <free threading>`
2+
build.

0 commit comments

Comments
 (0)