Skip to content

Commit b22f1fa

Browse files
committed
Add threading test for C extension
We already have some threading coverage in the general reader tests, but this provides more detailed implementation tests for the extension.
1 parent 23f1114 commit b22f1fa

File tree

1 file changed

+277
-0
lines changed

1 file changed

+277
-0
lines changed

tests/threading_test.py

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
"""Tests for thread-safety and free-threading support."""
2+
3+
from __future__ import annotations
4+
5+
import threading
6+
import time
7+
import unittest
8+
9+
import maxminddb
10+
11+
try:
12+
import maxminddb.extension
13+
except ImportError:
14+
maxminddb.extension = None
15+
16+
from maxminddb import open_database
17+
from maxminddb.const import MODE_MMAP_EXT
18+
19+
20+
def has_maxminddb_extension() -> bool:
21+
return maxminddb.extension is not None and hasattr(
22+
maxminddb.extension,
23+
"Reader",
24+
)
25+
26+
27+
@unittest.skipIf(
28+
not has_maxminddb_extension(),
29+
"No C extension module found. Skipping threading tests",
30+
)
31+
class TestThreadSafety(unittest.TestCase):
32+
"""Test thread safety of the C extension."""
33+
34+
def test_concurrent_reads(self) -> None:
35+
"""Test multiple threads reading concurrently."""
36+
reader = open_database(
37+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
38+
MODE_MMAP_EXT,
39+
)
40+
41+
results: list[dict | None] = [None] * 100
42+
errors: list[Exception] = []
43+
44+
def lookup(index: int, ip: str) -> None:
45+
try:
46+
results[index] = reader.get(ip)
47+
except Exception as e: # noqa: BLE001
48+
errors.append(e)
49+
50+
threads = []
51+
for i in range(100):
52+
ip = f"1.1.1.{(i % 32) + 1}"
53+
t = threading.Thread(target=lookup, args=(i, ip))
54+
threads.append(t)
55+
t.start()
56+
57+
for t in threads:
58+
t.join()
59+
60+
reader.close()
61+
62+
self.assertEqual(len(errors), 0, f"Errors during concurrent reads: {errors}")
63+
# All lookups should have completed
64+
self.assertNotIn(None, results)
65+
66+
def test_read_during_close(self) -> None:
67+
"""Test that close is safe when reads are happening concurrently."""
68+
reader = open_database(
69+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
70+
MODE_MMAP_EXT,
71+
)
72+
73+
errors: list[Exception] = []
74+
should_stop = threading.Event()
75+
76+
def continuous_reader() -> None:
77+
# Keep reading until signaled to stop or reader is closed
78+
while not should_stop.is_set():
79+
try:
80+
reader.get("1.1.1.1")
81+
except ValueError as e: # noqa: PERF203
82+
# Expected once close() is called
83+
if "closed MaxMind DB" not in str(e):
84+
errors.append(e)
85+
break
86+
except Exception as e: # noqa: BLE001
87+
errors.append(e)
88+
break
89+
90+
# Start multiple readers
91+
threads = [threading.Thread(target=continuous_reader) for _ in range(10)]
92+
for t in threads:
93+
t.start()
94+
95+
# Let readers run for a bit
96+
time.sleep(0.05)
97+
98+
# Close while reads are happening
99+
reader.close()
100+
101+
# Signal threads to stop
102+
should_stop.set()
103+
104+
# Wait for all threads
105+
for t in threads:
106+
t.join(timeout=1.0)
107+
108+
self.assertEqual(len(errors), 0, f"Errors during close test: {errors}")
109+
110+
def test_read_after_close(self) -> None:
111+
"""Test that reads after close raise appropriate error."""
112+
reader = open_database(
113+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
114+
MODE_MMAP_EXT,
115+
)
116+
reader.close()
117+
118+
with self.assertRaisesRegex(
119+
ValueError,
120+
"Attempt to read from a closed MaxMind DB",
121+
):
122+
reader.get("1.1.1.1")
123+
124+
def test_concurrent_reads_and_metadata(self) -> None:
125+
"""Test concurrent reads and metadata access."""
126+
reader = open_database(
127+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
128+
MODE_MMAP_EXT,
129+
)
130+
131+
errors: list[Exception] = []
132+
results: list[bool] = []
133+
134+
def do_reads() -> None:
135+
try:
136+
for _ in range(50):
137+
reader.get("1.1.1.1")
138+
results.append(True)
139+
except Exception as e: # noqa: BLE001
140+
errors.append(e)
141+
142+
def do_metadata() -> None:
143+
try:
144+
for _ in range(50):
145+
reader.metadata()
146+
results.append(True)
147+
except Exception as e: # noqa: BLE001
148+
errors.append(e)
149+
150+
threads = []
151+
for _ in range(5):
152+
threads.append(threading.Thread(target=do_reads))
153+
threads.append(threading.Thread(target=do_metadata))
154+
155+
for t in threads:
156+
t.start()
157+
158+
for t in threads:
159+
t.join()
160+
161+
reader.close()
162+
163+
self.assertEqual(
164+
len(errors), 0, f"Errors during concurrent operations: {errors}"
165+
)
166+
self.assertEqual(len(results), 10, "All threads should complete")
167+
168+
def test_concurrent_iteration(self) -> None:
169+
"""Test that iteration is thread-safe."""
170+
reader = open_database(
171+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
172+
MODE_MMAP_EXT,
173+
)
174+
175+
errors: list[Exception] = []
176+
counts: list[int] = []
177+
178+
def iterate() -> None:
179+
try:
180+
count = 0
181+
for _ in reader:
182+
count += 1
183+
counts.append(count)
184+
except Exception as e: # noqa: BLE001
185+
errors.append(e)
186+
187+
threads = [threading.Thread(target=iterate) for _ in range(10)]
188+
189+
for t in threads:
190+
t.start()
191+
192+
for t in threads:
193+
t.join()
194+
195+
reader.close()
196+
197+
self.assertEqual(len(errors), 0, f"Errors during iteration: {errors}")
198+
# All threads should see the same number of entries
199+
self.assertEqual(len(set(counts)), 1, "All threads should see same entry count")
200+
201+
def test_stress_test(self) -> None:
202+
"""Stress test with many threads and operations."""
203+
reader = open_database(
204+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
205+
MODE_MMAP_EXT,
206+
)
207+
208+
errors: list[Exception] = []
209+
operations_completed = threading.Event()
210+
211+
def random_operations() -> None:
212+
try:
213+
for i in range(100):
214+
# Mix different operations
215+
if i % 3 == 0:
216+
reader.get("1.1.1.1")
217+
elif i % 3 == 1:
218+
reader.metadata()
219+
else:
220+
reader.get_with_prefix_len("1.1.1.2")
221+
except Exception as e: # noqa: BLE001
222+
errors.append(e)
223+
224+
threads = [threading.Thread(target=random_operations) for _ in range(20)]
225+
226+
for t in threads:
227+
t.start()
228+
229+
for t in threads:
230+
t.join()
231+
232+
operations_completed.set()
233+
reader.close()
234+
235+
self.assertEqual(len(errors), 0, f"Errors during stress test: {errors}")
236+
237+
def test_multiple_readers_different_databases(self) -> None:
238+
"""Test multiple readers on different databases in parallel."""
239+
errors: list[Exception] = []
240+
241+
def use_reader(filename: str) -> None:
242+
try:
243+
reader = open_database(filename, MODE_MMAP_EXT)
244+
for _ in range(50):
245+
reader.get("1.1.1.1")
246+
reader.close()
247+
except Exception as e: # noqa: BLE001
248+
errors.append(e)
249+
250+
threads = [
251+
threading.Thread(
252+
target=use_reader,
253+
args=("tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",),
254+
)
255+
for _ in range(5)
256+
]
257+
threads.extend(
258+
[
259+
threading.Thread(
260+
target=use_reader,
261+
args=("tests/data/test-data/MaxMind-DB-test-ipv6-24.mmdb",),
262+
)
263+
for _ in range(5)
264+
]
265+
)
266+
267+
for t in threads:
268+
t.start()
269+
270+
for t in threads:
271+
t.join()
272+
273+
self.assertEqual(len(errors), 0, f"Errors with multiple readers: {errors}")
274+
275+
276+
if __name__ == "__main__":
277+
unittest.main()

0 commit comments

Comments
 (0)