|
15 | 15 | """Test the topology module."""
|
16 | 16 | from __future__ import annotations
|
17 | 17 |
|
| 18 | +import asyncio |
18 | 19 | import os
|
19 | 20 | import socketserver
|
20 | 21 | import sys
|
|
23 | 24 | sys.path[0:0] = [""]
|
24 | 25 |
|
25 | 26 | from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, unittest
|
26 |
| -from test.pymongo_mocks import DummyMonitor |
| 27 | +from test.asynchronous.pymongo_mocks import DummyMonitor |
27 | 28 | from test.unified_format import generate_test_classes
|
28 | 29 | from test.utils import (
|
29 | 30 | CMAPListener,
|
|
59 | 60 | _IS_SYNC = False
|
60 | 61 |
|
61 | 62 | # Location of JSON test specifications.
|
62 |
| -SDAM_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring") |
| 63 | +if _IS_SYNC: |
| 64 | + SDAM_PATH = os.path.join( |
| 65 | + os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring" |
| 66 | + ) |
| 67 | +else: |
| 68 | + SDAM_PATH = os.path.join( |
| 69 | + os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir)), |
| 70 | + "discovery_and_monitoring", |
| 71 | + ) |
63 | 72 |
|
64 | 73 |
|
65 | 74 | async def create_mock_topology(uri, monitor_class=DummyMonitor):
|
@@ -92,7 +101,7 @@ async def got_hello(topology, server_address, hello_response):
|
92 | 101 | await topology.on_change(server_description)
|
93 | 102 |
|
94 | 103 |
|
95 |
| -def got_app_error(topology, app_error): |
| 104 | +async def got_app_error(topology, app_error): |
96 | 105 | server_address = common.partition_node(app_error["address"])
|
97 | 106 | server = topology.get_server_by_address(server_address)
|
98 | 107 | error_type = app_error["type"]
|
@@ -120,7 +129,7 @@ def got_app_error(topology, app_error):
|
120 | 129 | else:
|
121 | 130 | raise AssertionError(f"Unknown when field {when}")
|
122 | 131 |
|
123 |
| - topology.handle_error( |
| 132 | + await topology.handle_error( |
124 | 133 | server_address,
|
125 | 134 | _ErrorContext(e, max_wire_version, generation, completed_handshake, None),
|
126 | 135 | )
|
@@ -207,12 +216,14 @@ async def run_scenario(self):
|
207 | 216 | for i, phase in enumerate(scenario_def["phases"]):
|
208 | 217 | # Including the phase description makes failures easier to debug.
|
209 | 218 | description = phase.get("description", str(i))
|
| 219 | + if self._testMethodName == "test_single_direct_connection_external_ip": |
| 220 | + print("here") |
210 | 221 | with assertion_context(f"phase: {description}"):
|
211 | 222 | for response in phase.get("responses", []):
|
212 |
| - got_hello(c, common.partition_node(response[0]), response[1]) |
| 223 | + await got_hello(c, common.partition_node(response[0]), response[1]) |
213 | 224 |
|
214 | 225 | for app_error in phase.get("applicationErrors", []):
|
215 |
| - got_app_error(c, app_error) |
| 226 | + await got_app_error(c, app_error) |
216 | 227 |
|
217 | 228 | check_outcome(self, c, phase["outcome"])
|
218 | 229 |
|
@@ -369,7 +380,7 @@ def predicate():
|
369 | 380 | return False
|
370 | 381 | return True
|
371 | 382 |
|
372 |
| - wait_until(predicate, "find all RTT monitors") |
| 383 | + await async_wait_until(predicate, "find all RTT monitors") |
373 | 384 |
|
374 | 385 | async def test_rtt_connection_is_disabled_poll(self):
|
375 | 386 | client = await self.async_rs_or_single_client(serverMonitoringMode="poll")
|
|
0 commit comments