Skip to content

Commit 234fb8b

Browse files
committed
Add threading test for C extension
We already have some threading coverage in the general reader tests, but this provides more detailed implementation tests for the extension.
1 parent 9ae9a71 commit 234fb8b

File tree

1 file changed

+273
-0
lines changed

1 file changed

+273
-0
lines changed

tests/threading_test.py

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
"""Tests for thread-safety and free-threading support."""
2+
3+
from __future__ import annotations
4+
5+
import threading
6+
import time
7+
import unittest
8+
from typing import TYPE_CHECKING
9+
10+
if TYPE_CHECKING:
11+
from maxminddb.types import Record
12+
13+
try:
14+
import maxminddb.extension # noqa: F401
15+
HAS_EXTENSION = True
16+
except ImportError:
17+
HAS_EXTENSION = False
18+
19+
from maxminddb import open_database
20+
from maxminddb.const import MODE_MMAP_EXT
21+
22+
23+
@unittest.skipIf(
24+
not HAS_EXTENSION,
25+
"No C extension module found. Skipping threading tests",
26+
)
27+
class TestThreadSafety(unittest.TestCase):
28+
"""Test thread safety of the C extension."""
29+
30+
def test_concurrent_reads(self) -> None:
31+
"""Test multiple threads reading concurrently."""
32+
reader = open_database(
33+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
34+
MODE_MMAP_EXT,
35+
)
36+
37+
results: list[Record | None] = [None] * 100
38+
errors: list[Exception] = []
39+
40+
def lookup(index: int, ip: str) -> None:
41+
try:
42+
results[index] = reader.get(ip)
43+
except Exception as e: # noqa: BLE001
44+
errors.append(e)
45+
46+
threads = []
47+
for i in range(100):
48+
ip = f"1.1.1.{(i % 32) + 1}"
49+
t = threading.Thread(target=lookup, args=(i, ip))
50+
threads.append(t)
51+
t.start()
52+
53+
for t in threads:
54+
t.join()
55+
56+
reader.close()
57+
58+
self.assertEqual(len(errors), 0, f"Errors during concurrent reads: {errors}")
59+
# All lookups should have completed
60+
self.assertNotIn(None, results)
61+
62+
def test_read_during_close(self) -> None:
63+
"""Test that close is safe when reads are happening concurrently."""
64+
reader = open_database(
65+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
66+
MODE_MMAP_EXT,
67+
)
68+
69+
errors: list[Exception] = []
70+
should_stop = threading.Event()
71+
72+
def continuous_reader() -> None:
73+
# Keep reading until signaled to stop or reader is closed
74+
while not should_stop.is_set():
75+
try:
76+
reader.get("1.1.1.1")
77+
except ValueError as e: # noqa: PERF203
78+
# Expected once close() is called
79+
if "closed MaxMind DB" not in str(e):
80+
errors.append(e)
81+
break
82+
except Exception as e: # noqa: BLE001
83+
errors.append(e)
84+
break
85+
86+
# Start multiple readers
87+
threads = [threading.Thread(target=continuous_reader) for _ in range(10)]
88+
for t in threads:
89+
t.start()
90+
91+
# Let readers run for a bit
92+
time.sleep(0.05)
93+
94+
# Close while reads are happening
95+
reader.close()
96+
97+
# Signal threads to stop
98+
should_stop.set()
99+
100+
# Wait for all threads
101+
for t in threads:
102+
t.join(timeout=1.0)
103+
104+
self.assertEqual(len(errors), 0, f"Errors during close test: {errors}")
105+
106+
def test_read_after_close(self) -> None:
107+
"""Test that reads after close raise appropriate error."""
108+
reader = open_database(
109+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
110+
MODE_MMAP_EXT,
111+
)
112+
reader.close()
113+
114+
with self.assertRaisesRegex(
115+
ValueError,
116+
"Attempt to read from a closed MaxMind DB",
117+
):
118+
reader.get("1.1.1.1")
119+
120+
def test_concurrent_reads_and_metadata(self) -> None:
121+
"""Test concurrent reads and metadata access."""
122+
reader = open_database(
123+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
124+
MODE_MMAP_EXT,
125+
)
126+
127+
errors: list[Exception] = []
128+
results: list[bool] = []
129+
130+
def do_reads() -> None:
131+
try:
132+
for _ in range(50):
133+
reader.get("1.1.1.1")
134+
results.append(True)
135+
except Exception as e: # noqa: BLE001
136+
errors.append(e)
137+
138+
def do_metadata() -> None:
139+
try:
140+
for _ in range(50):
141+
reader.metadata()
142+
results.append(True)
143+
except Exception as e: # noqa: BLE001
144+
errors.append(e)
145+
146+
threads = []
147+
for _ in range(5):
148+
threads.append(threading.Thread(target=do_reads))
149+
threads.append(threading.Thread(target=do_metadata))
150+
151+
for t in threads:
152+
t.start()
153+
154+
for t in threads:
155+
t.join()
156+
157+
reader.close()
158+
159+
self.assertEqual(
160+
len(errors), 0, f"Errors during concurrent operations: {errors}"
161+
)
162+
self.assertEqual(len(results), 10, "All threads should complete")
163+
164+
def test_concurrent_iteration(self) -> None:
165+
"""Test that iteration is thread-safe."""
166+
reader = open_database(
167+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
168+
MODE_MMAP_EXT,
169+
)
170+
171+
errors: list[Exception] = []
172+
counts: list[int] = []
173+
174+
def iterate() -> None:
175+
try:
176+
count = 0
177+
for _ in reader:
178+
count += 1
179+
counts.append(count)
180+
except Exception as e: # noqa: BLE001
181+
errors.append(e)
182+
183+
threads = [threading.Thread(target=iterate) for _ in range(10)]
184+
185+
for t in threads:
186+
t.start()
187+
188+
for t in threads:
189+
t.join()
190+
191+
reader.close()
192+
193+
self.assertEqual(len(errors), 0, f"Errors during iteration: {errors}")
194+
# All threads should see the same number of entries
195+
self.assertEqual(len(set(counts)), 1, "All threads should see same entry count")
196+
197+
def test_stress_test(self) -> None:
198+
"""Stress test with many threads and operations."""
199+
reader = open_database(
200+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
201+
MODE_MMAP_EXT,
202+
)
203+
204+
errors: list[Exception] = []
205+
operations_completed = threading.Event()
206+
207+
def random_operations() -> None:
208+
try:
209+
for i in range(100):
210+
# Mix different operations
211+
if i % 3 == 0:
212+
reader.get("1.1.1.1")
213+
elif i % 3 == 1:
214+
reader.metadata()
215+
else:
216+
reader.get_with_prefix_len("1.1.1.2")
217+
except Exception as e: # noqa: BLE001
218+
errors.append(e)
219+
220+
threads = [threading.Thread(target=random_operations) for _ in range(20)]
221+
222+
for t in threads:
223+
t.start()
224+
225+
for t in threads:
226+
t.join()
227+
228+
operations_completed.set()
229+
reader.close()
230+
231+
self.assertEqual(len(errors), 0, f"Errors during stress test: {errors}")
232+
233+
def test_multiple_readers_different_databases(self) -> None:
234+
"""Test multiple readers on different databases in parallel."""
235+
errors: list[Exception] = []
236+
237+
def use_reader(filename: str) -> None:
238+
try:
239+
reader = open_database(filename, MODE_MMAP_EXT)
240+
for _ in range(50):
241+
reader.get("1.1.1.1")
242+
reader.close()
243+
except Exception as e: # noqa: BLE001
244+
errors.append(e)
245+
246+
threads = [
247+
threading.Thread(
248+
target=use_reader,
249+
args=("tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",),
250+
)
251+
for _ in range(5)
252+
]
253+
threads.extend(
254+
[
255+
threading.Thread(
256+
target=use_reader,
257+
args=("tests/data/test-data/MaxMind-DB-test-ipv6-24.mmdb",),
258+
)
259+
for _ in range(5)
260+
]
261+
)
262+
263+
for t in threads:
264+
t.start()
265+
266+
for t in threads:
267+
t.join()
268+
269+
self.assertEqual(len(errors), 0, f"Errors with multiple readers: {errors}")
270+
271+
272+
if __name__ == "__main__":
273+
unittest.main()

0 commit comments

Comments
 (0)