20
20
import socketserver
21
21
import sys
22
22
import threading
23
+ from asyncio import StreamReader
23
24
from pathlib import Path
24
25
25
26
sys .path [0 :0 ] = ["" ]
26
27
27
- from test .asynchronous import AsyncIntegrationTest , AsyncPyMongoTestCase , unittest
28
+ from test .asynchronous import AsyncIntegrationTest , AsyncPyMongoTestCase , AsyncUnitTest , unittest
28
29
from test .asynchronous .pymongo_mocks import DummyMonitor
29
30
from test .asynchronous .unified_format import generate_test_classes
30
31
from test .utils import (
@@ -226,7 +227,7 @@ async def run_scenario(self):
226
227
return run_scenario
227
228
228
229
229
- def create_tests ():
230
+ async def create_tests ():
230
231
for dirpath , _ , filenames in os .walk (SDAM_PATH ):
231
232
dirname = os .path .split (dirpath )[- 1 ]
232
233
# SDAM unified tests are handled separately.
@@ -247,7 +248,6 @@ def create_tests():
247
248
setattr (TestAllScenarios , new_test .__name__ , new_test )
248
249
249
250
250
- create_tests ()
251
251
252
252
253
253
class TestClusterTimeComparison (unittest .IsolatedAsyncioTestCase ):
@@ -277,45 +277,82 @@ async def send_cluster_time(time, inc, should_update):
277
277
278
278
279
279
class TestIgnoreStaleErrors (AsyncIntegrationTest ):
280
- @async_client_context .require_sync
281
- async def test_ignore_stale_connection_errors (self ):
282
- N_THREADS = 5
283
- barrier = threading .Barrier (N_THREADS , timeout = 30 )
284
- client = await self .async_rs_or_single_client (minPoolSize = N_THREADS )
280
+ if _IS_SYNC :
281
+ async def test_ignore_stale_connection_errors (self ):
282
+ N_THREADS = 5
283
+ barrier = threading .Barrier (N_THREADS , timeout = 30 )
284
+ client = await self .async_rs_or_single_client (minPoolSize = N_THREADS )
285
+
286
+ # Wait for initial discovery.
287
+ await client .admin .command ("ping" )
288
+ pool = await async_get_pool (client )
289
+ starting_generation = pool .gen .get_overall ()
290
+ await async_wait_until (lambda : len (pool .conns ) == N_THREADS , "created conns" )
291
+
292
+ def mock_command (* args , ** kwargs ):
293
+ # Synchronize all threads to ensure they use the same generation.
294
+ barrier .wait ()
295
+ raise AutoReconnect ("mock AsyncConnection.command error" )
296
+
297
+ for conn in pool .conns :
298
+ conn .command = mock_command
299
+
300
+ async def insert_command (i ):
301
+ try :
302
+ await client .test .command ("insert" , "test" , documents = [{"i" : i }])
303
+ except AutoReconnect :
304
+ pass
305
+
306
+ threads = []
307
+ for i in range (N_THREADS ):
308
+ threads .append (threading .Thread (target = insert_command , args = (i ,)))
309
+ for t in threads :
310
+ t .start ()
311
+ for t in threads :
312
+ t .join ()
313
+
314
+ # Expect a single pool reset for the network error
315
+ self .assertEqual (starting_generation + 1 , pool .gen .get_overall ())
316
+
317
+ # Server should be selectable.
318
+ await client .admin .command ("ping" )
319
+ else :
320
+ async def test_ignore_stale_connection_errors (self ):
321
+ N_TASKS = 5
322
+ barrier = asyncio .Barrier (N_TASKS )
323
+ client = await self .async_rs_or_single_client (minPoolSize = N_TASKS )
285
324
286
- # Wait for initial discovery.
287
- await client .admin .command ("ping" )
288
- pool = await async_get_pool (client )
289
- starting_generation = pool .gen .get_overall ()
290
- await async_wait_until (lambda : len (pool .conns ) == N_THREADS , "created conns" )
291
-
292
- def mock_command (* args , ** kwargs ):
293
- # Synchronize all threads to ensure they use the same generation.
294
- barrier .wait ()
295
- raise AutoReconnect ("mock AsyncConnection.command error" )
296
-
297
- for conn in pool .conns :
298
- conn .command = mock_command
299
-
300
- async def insert_command (i ):
301
- try :
302
- await client .test .command ("insert" , "test" , documents = [{"i" : i }])
303
- except AutoReconnect :
304
- pass
305
-
306
- threads = []
307
- for i in range (N_THREADS ):
308
- threads .append (threading .Thread (target = insert_command , args = (i ,)))
309
- for t in threads :
310
- t .start ()
311
- for t in threads :
312
- t .join ()
313
-
314
- # Expect a single pool reset for the network error
315
- self .assertEqual (starting_generation + 1 , pool .gen .get_overall ())
316
-
317
- # Server should be selectable.
318
- await client .admin .command ("ping" )
325
+ # Wait for initial discovery.
326
+ await client .admin .command ("ping" )
327
+ pool = await async_get_pool (client )
328
+ starting_generation = pool .gen .get_overall ()
329
+ await async_wait_until (lambda : len (pool .conns ) == N_TASKS , "created conns" )
330
+
331
+ async def mock_command (* args , ** kwargs ):
332
+ # Synchronize all threads to ensure they use the same generation.
333
+ await asyncio .wait_for (barrier .wait (), timeout = 30 )
334
+ raise AutoReconnect ("mock AsyncConnection.command error" )
335
+
336
+ for conn in pool .conns :
337
+ conn .command = mock_command
338
+
339
+ async def insert_command (i ):
340
+ try :
341
+ await client .test .command ("insert" , "test" , documents = [{"i" : i }])
342
+ except AutoReconnect :
343
+ pass
344
+
345
+ tasks = []
346
+ for i in range (N_TASKS ):
347
+ tasks .append (asyncio .create_task (insert_command (i )))
348
+ for t in tasks :
349
+ await t
350
+
351
+ # Expect a single pool reset for the network error
352
+ self .assertEqual (starting_generation + 1 , pool .gen .get_overall ())
353
+
354
+ # Server should be selectable.
355
+ await client .admin .command ("ping" )
319
356
320
357
321
358
class CMAPHeartbeatListener (HeartbeatEventListener , CMAPListener ):
@@ -432,30 +469,62 @@ def handle_request_and_shutdown(self):
432
469
433
470
434
471
class TestHeartbeatStartOrdering (AsyncPyMongoTestCase ):
435
- @async_client_context .require_sync
436
- async def test_heartbeat_start_ordering (self ):
437
- events = []
438
- listener = HeartbeatEventsListListener (events )
439
- server = TCPServer (("localhost" , 9999 ), MockTCPHandler )
440
- server .events = events
441
- server_thread = threading .Thread (target = server .handle_request_and_shutdown )
442
- server_thread .start ()
443
- _c = await self .simple_client (
444
- "mongodb://localhost:9999" , serverSelectionTimeoutMS = 500 , event_listeners = (listener ,)
445
- )
446
- server_thread .join ()
447
- listener .wait_for_event (ServerHeartbeatStartedEvent , 1 )
448
- listener .wait_for_event (ServerHeartbeatFailedEvent , 1 )
449
-
450
- self .assertEqual (
451
- events ,
452
- [
453
- "serverHeartbeatStartedEvent" ,
454
- "client connected" ,
455
- "client hello received" ,
456
- "serverHeartbeatFailedEvent" ,
457
- ],
458
- )
472
+ if _IS_SYNC :
473
+ async def test_heartbeat_start_ordering (self ):
474
+ events = []
475
+ listener = HeartbeatEventsListListener (events )
476
+ server = TCPServer (("localhost" , 9999 ), MockTCPHandler )
477
+ server .events = events
478
+ server_thread = threading .Thread (target = server .handle_request_and_shutdown )
479
+ server_thread .start ()
480
+ _c = await self .simple_client (
481
+ "mongodb://localhost:9999" , serverSelectionTimeoutMS = 500 , event_listeners = (listener ,)
482
+ )
483
+ server_thread .join ()
484
+ listener .wait_for_event (ServerHeartbeatStartedEvent , 1 )
485
+ listener .wait_for_event (ServerHeartbeatFailedEvent , 1 )
486
+
487
+ self .assertEqual (
488
+ events ,
489
+ [
490
+ "serverHeartbeatStartedEvent" ,
491
+ "client connected" ,
492
+ "client hello received" ,
493
+ "serverHeartbeatFailedEvent" ,
494
+ ],
495
+ )
496
+ else :
497
+ async def test_heartbeat_start_ordering (self ):
498
+ events = []
499
+
500
+ async def handle_client (reader : StreamReader , writer ):
501
+ server .events .append ("client connected" )
502
+ print ("clent connected" )
503
+ if (await reader .read (1024 )).strip ():
504
+ server .events .append ("client hello received" )
505
+ print ("client helllo recieved" )
506
+ listener = HeartbeatEventsListListener (events )
507
+ server = await asyncio .start_server (handle_client , "localhost" , 9999 )
508
+ async with server :
509
+ server .events = events
510
+ _c = self .simple_client (
511
+ "mongodb://localhost:9999" , serverSelectionTimeoutMS = 500 , event_listeners = (listener ,)
512
+ )
513
+ server .close ()
514
+ server_task = asyncio .create_task (server .wait_closed ())
515
+ await server_task
516
+ await listener .async_wait_for_event (ServerHeartbeatStartedEvent , 1 )
517
+ await listener .async_wait_for_event (ServerHeartbeatFailedEvent , 1 )
518
+
519
+ self .assertEqual (
520
+ events ,
521
+ [
522
+ "serverHeartbeatStartedEvent" ,
523
+ "client connected" ,
524
+ "client hello received" ,
525
+ "serverHeartbeatFailedEvent" ,
526
+ ],
527
+ )
459
528
460
529
461
530
# Generate unified tests.
0 commit comments