Skip to content

Commit 568cc64

Browse files
committed
debugging..
1 parent 70d09d3 commit 568cc64

File tree

3 files changed

+268
-131
lines changed

3 files changed

+268
-131
lines changed

test/asynchronous/test_discovery_and_monitoring.py

Lines changed: 134 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@
2020
import socketserver
2121
import sys
2222
import threading
23+
from asyncio import StreamReader
2324
from pathlib import Path
2425

2526
sys.path[0:0] = [""]
2627

27-
from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, unittest
28+
from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, AsyncUnitTest, unittest
2829
from test.asynchronous.pymongo_mocks import DummyMonitor
2930
from test.asynchronous.unified_format import generate_test_classes
3031
from test.utils import (
@@ -226,7 +227,7 @@ async def run_scenario(self):
226227
return run_scenario
227228

228229

229-
def create_tests():
230+
async def create_tests():
230231
for dirpath, _, filenames in os.walk(SDAM_PATH):
231232
dirname = os.path.split(dirpath)[-1]
232233
# SDAM unified tests are handled separately.
@@ -247,7 +248,6 @@ def create_tests():
247248
setattr(TestAllScenarios, new_test.__name__, new_test)
248249

249250

250-
create_tests()
251251

252252

253253
class TestClusterTimeComparison(unittest.IsolatedAsyncioTestCase):
@@ -277,45 +277,82 @@ async def send_cluster_time(time, inc, should_update):
277277

278278

279279
class TestIgnoreStaleErrors(AsyncIntegrationTest):
280-
@async_client_context.require_sync
281-
async def test_ignore_stale_connection_errors(self):
282-
N_THREADS = 5
283-
barrier = threading.Barrier(N_THREADS, timeout=30)
284-
client = await self.async_rs_or_single_client(minPoolSize=N_THREADS)
280+
if _IS_SYNC:
281+
async def test_ignore_stale_connection_errors(self):
282+
N_THREADS = 5
283+
barrier = threading.Barrier(N_THREADS, timeout=30)
284+
client = await self.async_rs_or_single_client(minPoolSize=N_THREADS)
285+
286+
# Wait for initial discovery.
287+
await client.admin.command("ping")
288+
pool = await async_get_pool(client)
289+
starting_generation = pool.gen.get_overall()
290+
await async_wait_until(lambda: len(pool.conns) == N_THREADS, "created conns")
291+
292+
def mock_command(*args, **kwargs):
293+
# Synchronize all threads to ensure they use the same generation.
294+
barrier.wait()
295+
raise AutoReconnect("mock AsyncConnection.command error")
296+
297+
for conn in pool.conns:
298+
conn.command = mock_command
299+
300+
async def insert_command(i):
301+
try:
302+
await client.test.command("insert", "test", documents=[{"i": i}])
303+
except AutoReconnect:
304+
pass
305+
306+
threads = []
307+
for i in range(N_THREADS):
308+
threads.append(threading.Thread(target=insert_command, args=(i,)))
309+
for t in threads:
310+
t.start()
311+
for t in threads:
312+
t.join()
313+
314+
# Expect a single pool reset for the network error
315+
self.assertEqual(starting_generation + 1, pool.gen.get_overall())
316+
317+
# Server should be selectable.
318+
await client.admin.command("ping")
319+
else:
320+
async def test_ignore_stale_connection_errors(self):
321+
N_TASKS = 5
322+
barrier = asyncio.Barrier(N_TASKS)
323+
client = await self.async_rs_or_single_client(minPoolSize=N_TASKS)
285324

286-
# Wait for initial discovery.
287-
await client.admin.command("ping")
288-
pool = await async_get_pool(client)
289-
starting_generation = pool.gen.get_overall()
290-
await async_wait_until(lambda: len(pool.conns) == N_THREADS, "created conns")
291-
292-
def mock_command(*args, **kwargs):
293-
# Synchronize all threads to ensure they use the same generation.
294-
barrier.wait()
295-
raise AutoReconnect("mock AsyncConnection.command error")
296-
297-
for conn in pool.conns:
298-
conn.command = mock_command
299-
300-
async def insert_command(i):
301-
try:
302-
await client.test.command("insert", "test", documents=[{"i": i}])
303-
except AutoReconnect:
304-
pass
305-
306-
threads = []
307-
for i in range(N_THREADS):
308-
threads.append(threading.Thread(target=insert_command, args=(i,)))
309-
for t in threads:
310-
t.start()
311-
for t in threads:
312-
t.join()
313-
314-
# Expect a single pool reset for the network error
315-
self.assertEqual(starting_generation + 1, pool.gen.get_overall())
316-
317-
# Server should be selectable.
318-
await client.admin.command("ping")
325+
# Wait for initial discovery.
326+
await client.admin.command("ping")
327+
pool = await async_get_pool(client)
328+
starting_generation = pool.gen.get_overall()
329+
await async_wait_until(lambda: len(pool.conns) == N_TASKS, "created conns")
330+
331+
async def mock_command(*args, **kwargs):
332+
# Synchronize all threads to ensure they use the same generation.
333+
await asyncio.wait_for(barrier.wait(), timeout=30)
334+
raise AutoReconnect("mock AsyncConnection.command error")
335+
336+
for conn in pool.conns:
337+
conn.command = mock_command
338+
339+
async def insert_command(i):
340+
try:
341+
await client.test.command("insert", "test", documents=[{"i": i}])
342+
except AutoReconnect:
343+
pass
344+
345+
tasks = []
346+
for i in range(N_TASKS):
347+
tasks.append(asyncio.create_task(insert_command(i)))
348+
for t in tasks:
349+
await t
350+
351+
# Expect a single pool reset for the network error
352+
self.assertEqual(starting_generation + 1, pool.gen.get_overall())
353+
354+
# Server should be selectable.
355+
await client.admin.command("ping")
319356

320357

321358
class CMAPHeartbeatListener(HeartbeatEventListener, CMAPListener):
@@ -432,30 +469,62 @@ def handle_request_and_shutdown(self):
432469

433470

434471
class TestHeartbeatStartOrdering(AsyncPyMongoTestCase):
435-
@async_client_context.require_sync
436-
async def test_heartbeat_start_ordering(self):
437-
events = []
438-
listener = HeartbeatEventsListListener(events)
439-
server = TCPServer(("localhost", 9999), MockTCPHandler)
440-
server.events = events
441-
server_thread = threading.Thread(target=server.handle_request_and_shutdown)
442-
server_thread.start()
443-
_c = await self.simple_client(
444-
"mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,)
445-
)
446-
server_thread.join()
447-
listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
448-
listener.wait_for_event(ServerHeartbeatFailedEvent, 1)
449-
450-
self.assertEqual(
451-
events,
452-
[
453-
"serverHeartbeatStartedEvent",
454-
"client connected",
455-
"client hello received",
456-
"serverHeartbeatFailedEvent",
457-
],
458-
)
472+
if _IS_SYNC:
473+
async def test_heartbeat_start_ordering(self):
474+
events = []
475+
listener = HeartbeatEventsListListener(events)
476+
server = TCPServer(("localhost", 9999), MockTCPHandler)
477+
server.events = events
478+
server_thread = threading.Thread(target=server.handle_request_and_shutdown)
479+
server_thread.start()
480+
_c = await self.simple_client(
481+
"mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,)
482+
)
483+
server_thread.join()
484+
listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
485+
listener.wait_for_event(ServerHeartbeatFailedEvent, 1)
486+
487+
self.assertEqual(
488+
events,
489+
[
490+
"serverHeartbeatStartedEvent",
491+
"client connected",
492+
"client hello received",
493+
"serverHeartbeatFailedEvent",
494+
],
495+
)
496+
else:
497+
async def test_heartbeat_start_ordering(self):
498+
events = []
499+
500+
async def handle_client(reader: StreamReader, writer):
501+
server.events.append("client connected")
502+
print("clent connected")
503+
if (await reader.read(1024)).strip():
504+
server.events.append("client hello received")
505+
print("client helllo recieved")
506+
listener = HeartbeatEventsListListener(events)
507+
server = await asyncio.start_server(handle_client, "localhost", 9999)
508+
async with server:
509+
server.events = events
510+
_c = self.simple_client(
511+
"mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,)
512+
)
513+
server.close()
514+
server_task = asyncio.create_task(server.wait_closed())
515+
await server_task
516+
await listener.async_wait_for_event(ServerHeartbeatStartedEvent, 1)
517+
await listener.async_wait_for_event(ServerHeartbeatFailedEvent, 1)
518+
519+
self.assertEqual(
520+
events,
521+
[
522+
"serverHeartbeatStartedEvent",
523+
"client connected",
524+
"client hello received",
525+
"serverHeartbeatFailedEvent",
526+
],
527+
)
459528

460529

461530
# Generate unified tests.

0 commit comments

Comments
 (0)