Skip to content

Commit 24d30af

Browse files
committed
address review
1 parent 47d9ebd commit 24d30af

File tree

5 files changed

+112
-160
lines changed

5 files changed

+112
-160
lines changed

test/asynchronous/helpers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,3 +407,19 @@ async def run(self):
407407
await self.target(*self.args)
408408
finally:
409409
self.stopped = True
410+
411+
412+
def create_barrier(N_TASKS, timeout: float | None = None):
413+
return threading.Barrier(N_TASKS, timeout)
414+
415+
416+
def async_create_barrier(N_TASKS, timeout: float | None = None):
417+
return asyncio.Barrier(N_TASKS)
418+
419+
420+
def barrier_wait(barrier, timeout: float | None = None):
421+
barrier.wait()
422+
423+
424+
async def async_barrier_wait(barrier, timeout: float | None = None):
425+
await asyncio.wait_for(barrier.wait(), timeout)

test/asynchronous/test_discovery_and_monitoring.py

Lines changed: 39 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import threading
2323
from asyncio import StreamReader, StreamWriter
2424
from pathlib import Path
25+
from test.asynchronous.helpers import ConcurrentRunner, async_barrier_wait, async_create_barrier
2526

2627
sys.path[0:0] = [""]
2728

@@ -275,84 +276,44 @@ async def send_cluster_time(time, inc, should_update):
275276

276277

277278
class TestIgnoreStaleErrors(AsyncIntegrationTest):
278-
if _IS_SYNC:
279-
280-
async def test_ignore_stale_connection_errors(self):
281-
N_THREADS = 5
282-
barrier = threading.Barrier(N_THREADS, timeout=30)
283-
client = await self.async_rs_or_single_client(minPoolSize=N_THREADS)
284-
285-
# Wait for initial discovery.
286-
await client.admin.command("ping")
287-
pool = await async_get_pool(client)
288-
starting_generation = pool.gen.get_overall()
289-
await async_wait_until(lambda: len(pool.conns) == N_THREADS, "created conns")
290-
291-
def mock_command(*args, **kwargs):
292-
# Synchronize all threads to ensure they use the same generation.
293-
barrier.wait()
294-
raise AutoReconnect("mock AsyncConnection.command error")
295-
296-
for conn in pool.conns:
297-
conn.command = mock_command
298-
299-
async def insert_command(i):
300-
try:
301-
await client.test.command("insert", "test", documents=[{"i": i}])
302-
except AutoReconnect:
303-
pass
304-
305-
threads = []
306-
for i in range(N_THREADS):
307-
threads.append(threading.Thread(target=insert_command, args=(i,)))
308-
for t in threads:
309-
t.start()
310-
for t in threads:
311-
t.join()
312-
313-
# Expect a single pool reset for the network error
314-
self.assertEqual(starting_generation + 1, pool.gen.get_overall())
315-
316-
# Server should be selectable.
317-
await client.admin.command("ping")
318-
else:
319-
320-
async def test_ignore_stale_connection_errors(self):
321-
N_TASKS = 5
322-
barrier = asyncio.Barrier(N_TASKS)
323-
client = await self.async_rs_or_single_client(minPoolSize=N_TASKS)
324-
325-
# Wait for initial discovery.
326-
await client.admin.command("ping")
327-
pool = await async_get_pool(client)
328-
starting_generation = pool.gen.get_overall()
329-
await async_wait_until(lambda: len(pool.conns) == N_TASKS, "created conns")
330-
331-
async def mock_command(*args, **kwargs):
332-
# Synchronize all threads to ensure they use the same generation.
333-
await asyncio.wait_for(barrier.wait(), timeout=30)
334-
raise AutoReconnect("mock AsyncConnection.command error")
335-
336-
for conn in pool.conns:
337-
conn.command = mock_command
279+
async def test_ignore_stale_connection_errors(self):
280+
N_TASKS = 5
281+
barrier = async_create_barrier(N_TASKS, timeout=30)
282+
client = await self.async_rs_or_single_client(minPoolSize=N_TASKS)
338283

339-
async def insert_command(i):
340-
try:
341-
await client.test.command("insert", "test", documents=[{"i": i}])
342-
except AutoReconnect:
343-
pass
344-
345-
tasks = []
346-
for i in range(N_TASKS):
347-
tasks.append(asyncio.create_task(insert_command(i)))
348-
for t in tasks:
349-
await t
350-
351-
# Expect a single pool reset for the network error
352-
self.assertEqual(starting_generation + 1, pool.gen.get_overall())
353-
354-
# Server should be selectable.
355-
await client.admin.command("ping")
284+
# Wait for initial discovery.
285+
await client.admin.command("ping")
286+
pool = await async_get_pool(client)
287+
starting_generation = pool.gen.get_overall()
288+
await async_wait_until(lambda: len(pool.conns) == N_TASKS, "created conns")
289+
290+
async def mock_command(*args, **kwargs):
291+
# Synchronize all threads to ensure they use the same generation.
292+
await async_barrier_wait(barrier, timeout=30)
293+
raise AutoReconnect("mock AsyncConnection.command error")
294+
295+
for conn in pool.conns:
296+
conn.command = mock_command
297+
298+
async def insert_command(i):
299+
try:
300+
await client.test.command("insert", "test", documents=[{"i": i}])
301+
except AutoReconnect:
302+
pass
303+
304+
tasks = []
305+
for i in range(N_TASKS):
306+
tasks.append(ConcurrentRunner(target=insert_command, args=(i,)))
307+
for t in tasks:
308+
await t.start()
309+
for t in tasks:
310+
await t.join()
311+
312+
# Expect a single pool reset for the network error
313+
self.assertEqual(starting_generation + 1, pool.gen.get_overall())
314+
315+
# Server should be selectable.
316+
await client.admin.command("ping")
356317

357318

358319
class CMAPHeartbeatListener(HeartbeatEventListener, CMAPListener):
@@ -499,14 +460,12 @@ async def handle_client(reader: StreamReader, writer: StreamWriter):
499460
server = await asyncio.start_server(handle_client, "localhost", 9999)
500461
server.events = events
501462
await server.start_serving()
502-
print(server.is_serving())
503463
_c = self.simple_client(
504464
"mongodb://localhost:9999",
505465
serverSelectionTimeoutMS=500,
506466
event_listeners=(listener,),
507467
)
508-
if _c._options.connect:
509-
await _c.aconnect()
468+
await _c.aconnect()
510469

511470
await listener.async_wait_for_event(ServerHeartbeatStartedEvent, 1)
512471
await listener.async_wait_for_event(ServerHeartbeatFailedEvent, 1)

test/helpers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,3 +407,19 @@ def run(self):
407407
self.target(*self.args)
408408
finally:
409409
self.stopped = True
410+
411+
412+
def create_barrier(N_TASKS, timeout: float | None = None):
413+
return threading.Barrier(N_TASKS, timeout)
414+
415+
416+
def create_barrier(N_TASKS, timeout: float | None = None):
417+
return asyncio.Barrier(N_TASKS)
418+
419+
420+
def barrier_wait(barrier, timeout: float | None = None):
421+
barrier.wait()
422+
423+
424+
def barrier_wait(barrier, timeout: float | None = None):
425+
asyncio.wait_for(barrier.wait(), timeout)

test/test_discovery_and_monitoring.py

Lines changed: 39 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import threading
2323
from asyncio import StreamReader, StreamWriter
2424
from pathlib import Path
25+
from test.helpers import ConcurrentRunner, barrier_wait, create_barrier
2526

2627
sys.path[0:0] = [""]
2728

@@ -275,84 +276,44 @@ def send_cluster_time(time, inc, should_update):
275276

276277

277278
class TestIgnoreStaleErrors(IntegrationTest):
278-
if _IS_SYNC:
279-
280-
def test_ignore_stale_connection_errors(self):
281-
N_THREADS = 5
282-
barrier = threading.Barrier(N_THREADS, timeout=30)
283-
client = self.rs_or_single_client(minPoolSize=N_THREADS)
284-
285-
# Wait for initial discovery.
286-
client.admin.command("ping")
287-
pool = get_pool(client)
288-
starting_generation = pool.gen.get_overall()
289-
wait_until(lambda: len(pool.conns) == N_THREADS, "created conns")
290-
291-
def mock_command(*args, **kwargs):
292-
# Synchronize all threads to ensure they use the same generation.
293-
barrier.wait()
294-
raise AutoReconnect("mock Connection.command error")
295-
296-
for conn in pool.conns:
297-
conn.command = mock_command
298-
299-
def insert_command(i):
300-
try:
301-
client.test.command("insert", "test", documents=[{"i": i}])
302-
except AutoReconnect:
303-
pass
304-
305-
threads = []
306-
for i in range(N_THREADS):
307-
threads.append(threading.Thread(target=insert_command, args=(i,)))
308-
for t in threads:
309-
t.start()
310-
for t in threads:
311-
t.join()
312-
313-
# Expect a single pool reset for the network error
314-
self.assertEqual(starting_generation + 1, pool.gen.get_overall())
315-
316-
# Server should be selectable.
317-
client.admin.command("ping")
318-
else:
319-
320-
def test_ignore_stale_connection_errors(self):
321-
N_TASKS = 5
322-
barrier = asyncio.Barrier(N_TASKS)
323-
client = self.rs_or_single_client(minPoolSize=N_TASKS)
324-
325-
# Wait for initial discovery.
326-
client.admin.command("ping")
327-
pool = get_pool(client)
328-
starting_generation = pool.gen.get_overall()
329-
wait_until(lambda: len(pool.conns) == N_TASKS, "created conns")
330-
331-
def mock_command(*args, **kwargs):
332-
# Synchronize all threads to ensure they use the same generation.
333-
asyncio.wait_for(barrier.wait(), timeout=30)
334-
raise AutoReconnect("mock Connection.command error")
335-
336-
for conn in pool.conns:
337-
conn.command = mock_command
279+
def test_ignore_stale_connection_errors(self):
280+
N_TASKS = 5
281+
barrier = create_barrier(N_TASKS, timeout=30)
282+
client = self.rs_or_single_client(minPoolSize=N_TASKS)
338283

339-
def insert_command(i):
340-
try:
341-
client.test.command("insert", "test", documents=[{"i": i}])
342-
except AutoReconnect:
343-
pass
344-
345-
tasks = []
346-
for i in range(N_TASKS):
347-
tasks.append(asyncio.create_task(insert_command(i)))
348-
for t in tasks:
349-
t
350-
351-
# Expect a single pool reset for the network error
352-
self.assertEqual(starting_generation + 1, pool.gen.get_overall())
353-
354-
# Server should be selectable.
355-
client.admin.command("ping")
284+
# Wait for initial discovery.
285+
client.admin.command("ping")
286+
pool = get_pool(client)
287+
starting_generation = pool.gen.get_overall()
288+
wait_until(lambda: len(pool.conns) == N_TASKS, "created conns")
289+
290+
def mock_command(*args, **kwargs):
291+
# Synchronize all threads to ensure they use the same generation.
292+
barrier_wait(barrier, timeout=30)
293+
raise AutoReconnect("mock Connection.command error")
294+
295+
for conn in pool.conns:
296+
conn.command = mock_command
297+
298+
def insert_command(i):
299+
try:
300+
client.test.command("insert", "test", documents=[{"i": i}])
301+
except AutoReconnect:
302+
pass
303+
304+
tasks = []
305+
for i in range(N_TASKS):
306+
tasks.append(ConcurrentRunner(target=insert_command, args=(i,)))
307+
for t in tasks:
308+
t.start()
309+
for t in tasks:
310+
t.join()
311+
312+
# Expect a single pool reset for the network error
313+
self.assertEqual(starting_generation + 1, pool.gen.get_overall())
314+
315+
# Server should be selectable.
316+
client.admin.command("ping")
356317

357318

358319
class CMAPHeartbeatListener(HeartbeatEventListener, CMAPListener):
@@ -499,14 +460,12 @@ def handle_client(reader: StreamReader, writer: StreamWriter):
499460
server = asyncio.start_server(handle_client, "localhost", 9999)
500461
server.events = events
501462
server.start_serving()
502-
print(server.is_serving())
503463
_c = self.simple_client(
504464
"mongodb://localhost:9999",
505465
serverSelectionTimeoutMS=500,
506466
event_listeners=(listener,),
507467
)
508-
if _c._options.connect:
509-
_c._connect()
468+
_c._connect()
510469

511470
listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
512471
listener.wait_for_event(ServerHeartbeatFailedEvent, 1)

tools/synchro.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@
122122
"SpecRunnerTask": "SpecRunnerThread",
123123
"AsyncMockConnection": "MockConnection",
124124
"AsyncMockPool": "MockPool",
125+
"async_create_barrier": "create_barrier",
126+
"async_barrier_wait": "barrier_wait",
125127
}
126128

127129
docstring_replacements: dict[tuple[str, str], str] = {

0 commit comments

Comments
 (0)