Skip to content

Commit f135c59

Browse files
committed
Add threading test for C extension
We already have some threading coverage in the general reader tests, but this provides more detailed implementation tests for the extension.
1 parent 9f15bec commit f135c59

File tree

1 file changed

+274
-0
lines changed

1 file changed

+274
-0
lines changed

tests/threading_test.py

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
"""Tests for thread-safety and free-threading support."""
2+
3+
from __future__ import annotations
4+
5+
import threading
6+
import time
7+
import unittest
8+
from typing import TYPE_CHECKING
9+
10+
if TYPE_CHECKING:
11+
from maxminddb.types import Record
12+
13+
try:
14+
import maxminddb.extension # noqa: F401
15+
16+
HAS_EXTENSION = True
17+
except ImportError:
18+
HAS_EXTENSION = False
19+
20+
from maxminddb import open_database
21+
from maxminddb.const import MODE_MMAP_EXT
22+
23+
24+
@unittest.skipIf(
25+
not HAS_EXTENSION,
26+
"No C extension module found. Skipping threading tests",
27+
)
28+
class TestThreadSafety(unittest.TestCase):
29+
"""Test thread safety of the C extension."""
30+
31+
def test_concurrent_reads(self) -> None:
32+
"""Test multiple threads reading concurrently."""
33+
reader = open_database(
34+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
35+
MODE_MMAP_EXT,
36+
)
37+
38+
results: list[Record | None] = [None] * 100
39+
errors: list[Exception] = []
40+
41+
def lookup(index: int, ip: str) -> None:
42+
try:
43+
results[index] = reader.get(ip)
44+
except Exception as e: # noqa: BLE001
45+
errors.append(e)
46+
47+
threads = []
48+
for i in range(100):
49+
ip = f"1.1.1.{(i % 32) + 1}"
50+
t = threading.Thread(target=lookup, args=(i, ip))
51+
threads.append(t)
52+
t.start()
53+
54+
for t in threads:
55+
t.join()
56+
57+
reader.close()
58+
59+
self.assertEqual(len(errors), 0, f"Errors during concurrent reads: {errors}")
60+
# All lookups should have completed
61+
self.assertNotIn(None, results)
62+
63+
def test_read_during_close(self) -> None:
64+
"""Test that close is safe when reads are happening concurrently."""
65+
reader = open_database(
66+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
67+
MODE_MMAP_EXT,
68+
)
69+
70+
errors: list[Exception] = []
71+
should_stop = threading.Event()
72+
73+
def continuous_reader() -> None:
74+
# Keep reading until signaled to stop or reader is closed
75+
while not should_stop.is_set():
76+
try:
77+
reader.get("1.1.1.1")
78+
except ValueError as e: # noqa: PERF203
79+
# Expected once close() is called
80+
if "closed MaxMind DB" not in str(e):
81+
errors.append(e)
82+
break
83+
except Exception as e: # noqa: BLE001
84+
errors.append(e)
85+
break
86+
87+
# Start multiple readers
88+
threads = [threading.Thread(target=continuous_reader) for _ in range(10)]
89+
for t in threads:
90+
t.start()
91+
92+
# Let readers run for a bit
93+
time.sleep(0.05)
94+
95+
# Close while reads are happening
96+
reader.close()
97+
98+
# Signal threads to stop
99+
should_stop.set()
100+
101+
# Wait for all threads
102+
for t in threads:
103+
t.join(timeout=1.0)
104+
105+
self.assertEqual(len(errors), 0, f"Errors during close test: {errors}")
106+
107+
def test_read_after_close(self) -> None:
108+
"""Test that reads after close raise appropriate error."""
109+
reader = open_database(
110+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
111+
MODE_MMAP_EXT,
112+
)
113+
reader.close()
114+
115+
with self.assertRaisesRegex(
116+
ValueError,
117+
"Attempt to read from a closed MaxMind DB",
118+
):
119+
reader.get("1.1.1.1")
120+
121+
def test_concurrent_reads_and_metadata(self) -> None:
122+
"""Test concurrent reads and metadata access."""
123+
reader = open_database(
124+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
125+
MODE_MMAP_EXT,
126+
)
127+
128+
errors: list[Exception] = []
129+
results: list[bool] = []
130+
131+
def do_reads() -> None:
132+
try:
133+
for _ in range(50):
134+
reader.get("1.1.1.1")
135+
results.append(True)
136+
except Exception as e: # noqa: BLE001
137+
errors.append(e)
138+
139+
def do_metadata() -> None:
140+
try:
141+
for _ in range(50):
142+
reader.metadata()
143+
results.append(True)
144+
except Exception as e: # noqa: BLE001
145+
errors.append(e)
146+
147+
threads = []
148+
for _ in range(5):
149+
threads.append(threading.Thread(target=do_reads))
150+
threads.append(threading.Thread(target=do_metadata))
151+
152+
for t in threads:
153+
t.start()
154+
155+
for t in threads:
156+
t.join()
157+
158+
reader.close()
159+
160+
self.assertEqual(
161+
len(errors), 0, f"Errors during concurrent operations: {errors}"
162+
)
163+
self.assertEqual(len(results), 10, "All threads should complete")
164+
165+
def test_concurrent_iteration(self) -> None:
166+
"""Test that iteration is thread-safe."""
167+
reader = open_database(
168+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
169+
MODE_MMAP_EXT,
170+
)
171+
172+
errors: list[Exception] = []
173+
counts: list[int] = []
174+
175+
def iterate() -> None:
176+
try:
177+
count = 0
178+
for _ in reader:
179+
count += 1
180+
counts.append(count)
181+
except Exception as e: # noqa: BLE001
182+
errors.append(e)
183+
184+
threads = [threading.Thread(target=iterate) for _ in range(10)]
185+
186+
for t in threads:
187+
t.start()
188+
189+
for t in threads:
190+
t.join()
191+
192+
reader.close()
193+
194+
self.assertEqual(len(errors), 0, f"Errors during iteration: {errors}")
195+
# All threads should see the same number of entries
196+
self.assertEqual(len(set(counts)), 1, "All threads should see same entry count")
197+
198+
def test_stress_test(self) -> None:
199+
"""Stress test with many threads and operations."""
200+
reader = open_database(
201+
"tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",
202+
MODE_MMAP_EXT,
203+
)
204+
205+
errors: list[Exception] = []
206+
operations_completed = threading.Event()
207+
208+
def random_operations() -> None:
209+
try:
210+
for i in range(100):
211+
# Mix different operations
212+
if i % 3 == 0:
213+
reader.get("1.1.1.1")
214+
elif i % 3 == 1:
215+
reader.metadata()
216+
else:
217+
reader.get_with_prefix_len("1.1.1.2")
218+
except Exception as e: # noqa: BLE001
219+
errors.append(e)
220+
221+
threads = [threading.Thread(target=random_operations) for _ in range(20)]
222+
223+
for t in threads:
224+
t.start()
225+
226+
for t in threads:
227+
t.join()
228+
229+
operations_completed.set()
230+
reader.close()
231+
232+
self.assertEqual(len(errors), 0, f"Errors during stress test: {errors}")
233+
234+
def test_multiple_readers_different_databases(self) -> None:
235+
"""Test multiple readers on different databases in parallel."""
236+
errors: list[Exception] = []
237+
238+
def use_reader(filename: str) -> None:
239+
try:
240+
reader = open_database(filename, MODE_MMAP_EXT)
241+
for _ in range(50):
242+
reader.get("1.1.1.1")
243+
reader.close()
244+
except Exception as e: # noqa: BLE001
245+
errors.append(e)
246+
247+
threads = [
248+
threading.Thread(
249+
target=use_reader,
250+
args=("tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb",),
251+
)
252+
for _ in range(5)
253+
]
254+
threads.extend(
255+
[
256+
threading.Thread(
257+
target=use_reader,
258+
args=("tests/data/test-data/MaxMind-DB-test-ipv6-24.mmdb",),
259+
)
260+
for _ in range(5)
261+
]
262+
)
263+
264+
for t in threads:
265+
t.start()
266+
267+
for t in threads:
268+
t.join()
269+
270+
self.assertEqual(len(errors), 0, f"Errors with multiple readers: {errors}")
271+
272+
273+
if __name__ == "__main__":
274+
unittest.main()

0 commit comments

Comments
 (0)