Skip to content

Commit 47d9ebd

Browse files
committed
make test_heartbeat_start_ordering async
1 parent d7a4a28 commit 47d9ebd

File tree

3 files changed

+95
-97
lines changed

3 files changed

+95
-97
lines changed

test/asynchronous/test_discovery_and_monitoring.py

Lines changed: 46 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import socketserver
2121
import sys
2222
import threading
23-
from asyncio import StreamReader
23+
from asyncio import StreamReader, StreamWriter
2424
from pathlib import Path
2525

2626
sys.path[0:0] = [""]
@@ -227,7 +227,7 @@ async def run_scenario(self):
227227
return run_scenario
228228

229229

230-
async def create_tests():
230+
def create_tests():
231231
for dirpath, _, filenames in os.walk(SDAM_PATH):
232232
dirname = os.path.split(dirpath)[-1]
233233
# SDAM unified tests are handled separately.
@@ -248,8 +248,6 @@ async def create_tests():
248248
setattr(TestAllScenarios, new_test.__name__, new_test)
249249

250250

251-
252-
253251
class TestClusterTimeComparison(unittest.IsolatedAsyncioTestCase):
254252
async def test_cluster_time_comparison(self):
255253
t = await create_mock_topology("mongodb://host")
@@ -278,6 +276,7 @@ async def send_cluster_time(time, inc, should_update):
278276

279277
class TestIgnoreStaleErrors(AsyncIntegrationTest):
280278
if _IS_SYNC:
279+
281280
async def test_ignore_stale_connection_errors(self):
282281
N_THREADS = 5
283282
barrier = threading.Barrier(N_THREADS, timeout=30)
@@ -317,6 +316,7 @@ async def insert_command(i):
317316
# Server should be selectable.
318317
await client.admin.command("ping")
319318
else:
319+
320320
async def test_ignore_stale_connection_errors(self):
321321
N_TASKS = 5
322322
barrier = asyncio.Barrier(N_TASKS)
@@ -469,62 +469,61 @@ def handle_request_and_shutdown(self):
469469

470470

471471
class TestHeartbeatStartOrdering(AsyncPyMongoTestCase):
472-
if _IS_SYNC:
473-
async def test_heartbeat_start_ordering(self):
474-
events = []
475-
listener = HeartbeatEventsListListener(events)
472+
async def test_heartbeat_start_ordering(self):
473+
events = []
474+
listener = HeartbeatEventsListListener(events)
475+
476+
if _IS_SYNC:
476477
server = TCPServer(("localhost", 9999), MockTCPHandler)
477478
server.events = events
478479
server_thread = threading.Thread(target=server.handle_request_and_shutdown)
479480
server_thread.start()
480481
_c = await self.simple_client(
481-
"mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,)
482+
"mongodb://localhost:9999",
483+
serverSelectionTimeoutMS=500,
484+
event_listeners=(listener,),
482485
)
483486
server_thread.join()
484487
listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
485488
listener.wait_for_event(ServerHeartbeatFailedEvent, 1)
486489

487-
self.assertEqual(
488-
events,
489-
[
490-
"serverHeartbeatStartedEvent",
491-
"client connected",
492-
"client hello received",
493-
"serverHeartbeatFailedEvent",
494-
],
495-
)
496-
else:
497-
async def test_heartbeat_start_ordering(self):
498-
events = []
490+
else:
499491

500-
async def handle_client(reader: StreamReader, writer):
501-
server.events.append("client connected")
502-
print("clent connected")
492+
async def handle_client(reader: StreamReader, writer: StreamWriter):
493+
events.append("client connected")
503494
if (await reader.read(1024)).strip():
504-
server.events.append("client hello received")
505-
print("client helllo recieved")
506-
listener = HeartbeatEventsListListener(events)
495+
events.append("client hello received")
496+
writer.close()
497+
await writer.wait_closed()
498+
507499
server = await asyncio.start_server(handle_client, "localhost", 9999)
508-
async with server:
509-
server.events = events
510-
_c = self.simple_client(
511-
"mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,)
512-
)
513-
server.close()
514-
server_task = asyncio.create_task(server.wait_closed())
515-
await server_task
516-
await listener.async_wait_for_event(ServerHeartbeatStartedEvent, 1)
517-
await listener.async_wait_for_event(ServerHeartbeatFailedEvent, 1)
518-
519-
self.assertEqual(
520-
events,
521-
[
522-
"serverHeartbeatStartedEvent",
523-
"client connected",
524-
"client hello received",
525-
"serverHeartbeatFailedEvent",
526-
],
527-
)
500+
server.events = events
501+
await server.start_serving()
502+
print(server.is_serving())
503+
_c = self.simple_client(
504+
"mongodb://localhost:9999",
505+
serverSelectionTimeoutMS=500,
506+
event_listeners=(listener,),
507+
)
508+
if _c._options.connect:
509+
await _c.aconnect()
510+
511+
await listener.async_wait_for_event(ServerHeartbeatStartedEvent, 1)
512+
await listener.async_wait_for_event(ServerHeartbeatFailedEvent, 1)
513+
514+
server.close()
515+
await server.wait_closed()
516+
await _c.close()
517+
518+
self.assertEqual(
519+
events,
520+
[
521+
"serverHeartbeatStartedEvent",
522+
"client connected",
523+
"client hello received",
524+
"serverHeartbeatFailedEvent",
525+
],
526+
)
528527

529528

530529
# Generate unified tests.

test/asynchronous/unified_format.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,7 +1155,7 @@ def _testOperation_assertTopologyType(self, spec):
11551155
self.assertIsInstance(description, TopologyDescription)
11561156
self.assertEqual(description.topology_type_name, spec["topologyType"])
11571157

1158-
def _testOperation_waitForPrimaryChange(self, spec: dict) -> None:
1158+
async def _testOperation_waitForPrimaryChange(self, spec: dict) -> None:
11591159
"""Run the waitForPrimaryChange test operation."""
11601160
client = self.entity_map[spec["client"]]
11611161
old_description: TopologyDescription = self.entity_map[spec["priorTopologyDescription"]]
@@ -1169,13 +1169,13 @@ def get_primary(td: TopologyDescription) -> Optional[_Address]:
11691169

11701170
old_primary = get_primary(old_description)
11711171

1172-
def primary_changed() -> bool:
1173-
primary = client.primary
1172+
async def primary_changed() -> bool:
1173+
primary = await client.primary
11741174
if primary is None:
11751175
return False
11761176
return primary != old_primary
11771177

1178-
wait_until(primary_changed, "change primary", timeout=timeout)
1178+
await async_wait_until(primary_changed, "change primary", timeout=timeout)
11791179

11801180
async def _testOperation_runOnThread(self, spec):
11811181
"""Run the 'runOnThread' operation."""

test/test_discovery_and_monitoring.py

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import socketserver
2121
import sys
2222
import threading
23-
from asyncio import StreamReader
23+
from asyncio import StreamReader, StreamWriter
2424
from pathlib import Path
2525

2626
sys.path[0:0] = [""]
@@ -248,8 +248,6 @@ def create_tests():
248248
setattr(TestAllScenarios, new_test.__name__, new_test)
249249

250250

251-
252-
253251
class TestClusterTimeComparison(unittest.TestCase):
254252
def test_cluster_time_comparison(self):
255253
t = create_mock_topology("mongodb://host")
@@ -278,6 +276,7 @@ def send_cluster_time(time, inc, should_update):
278276

279277
class TestIgnoreStaleErrors(IntegrationTest):
280278
if _IS_SYNC:
279+
281280
def test_ignore_stale_connection_errors(self):
282281
N_THREADS = 5
283282
barrier = threading.Barrier(N_THREADS, timeout=30)
@@ -317,6 +316,7 @@ def insert_command(i):
317316
# Server should be selectable.
318317
client.admin.command("ping")
319318
else:
319+
320320
def test_ignore_stale_connection_errors(self):
321321
N_TASKS = 5
322322
barrier = asyncio.Barrier(N_TASKS)
@@ -469,62 +469,61 @@ def handle_request_and_shutdown(self):
469469

470470

471471
class TestHeartbeatStartOrdering(PyMongoTestCase):
472-
if _IS_SYNC:
473-
def test_heartbeat_start_ordering(self):
474-
events = []
475-
listener = HeartbeatEventsListListener(events)
472+
def test_heartbeat_start_ordering(self):
473+
events = []
474+
listener = HeartbeatEventsListListener(events)
475+
476+
if _IS_SYNC:
476477
server = TCPServer(("localhost", 9999), MockTCPHandler)
477478
server.events = events
478479
server_thread = threading.Thread(target=server.handle_request_and_shutdown)
479480
server_thread.start()
480481
_c = self.simple_client(
481-
"mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,)
482+
"mongodb://localhost:9999",
483+
serverSelectionTimeoutMS=500,
484+
event_listeners=(listener,),
482485
)
483486
server_thread.join()
484487
listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
485488
listener.wait_for_event(ServerHeartbeatFailedEvent, 1)
486489

487-
self.assertEqual(
488-
events,
489-
[
490-
"serverHeartbeatStartedEvent",
491-
"client connected",
492-
"client hello received",
493-
"serverHeartbeatFailedEvent",
494-
],
495-
)
496-
else:
497-
def test_heartbeat_start_ordering(self):
498-
events = []
490+
else:
499491

500-
def handle_client(reader: StreamReader, writer):
501-
server.events.append("client connected")
502-
print("clent connected")
492+
def handle_client(reader: StreamReader, writer: StreamWriter):
493+
events.append("client connected")
503494
if (reader.read(1024)).strip():
504-
server.events.append("client hello received")
505-
print("client helllo recieved")
506-
listener = HeartbeatEventsListListener(events)
495+
events.append("client hello received")
496+
writer.close()
497+
writer.wait_closed()
498+
507499
server = asyncio.start_server(handle_client, "localhost", 9999)
508-
with server:
509-
server.events = events
510-
_c = self.simple_client(
511-
"mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,)
512-
)
513-
server.close()
514-
server_task = asyncio.create_task(server.wait_closed())
515-
server_task
516-
listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
517-
listener.wait_for_event(ServerHeartbeatFailedEvent, 1)
518-
519-
self.assertEqual(
520-
events,
521-
[
522-
"serverHeartbeatStartedEvent",
523-
"client connected",
524-
"client hello received",
525-
"serverHeartbeatFailedEvent",
526-
],
527-
)
500+
server.events = events
501+
server.start_serving()
502+
print(server.is_serving())
503+
_c = self.simple_client(
504+
"mongodb://localhost:9999",
505+
serverSelectionTimeoutMS=500,
506+
event_listeners=(listener,),
507+
)
508+
if _c._options.connect:
509+
_c._connect()
510+
511+
listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
512+
listener.wait_for_event(ServerHeartbeatFailedEvent, 1)
513+
514+
server.close()
515+
server.wait_closed()
516+
_c.close()
517+
518+
self.assertEqual(
519+
events,
520+
[
521+
"serverHeartbeatStartedEvent",
522+
"client connected",
523+
"client hello received",
524+
"serverHeartbeatFailedEvent",
525+
],
526+
)
528527

529528

530529
# Generate unified tests.

0 commit comments

Comments
 (0)