Skip to content

Commit 7a17e8f

Browse files
committed
add missing awaits and changed path; some tests are failing but pushing what I have
1 parent 926c176 commit 7a17e8f

File tree

4 files changed

+35
-10
lines changed

4 files changed

+35
-10
lines changed

test/asynchronous/pymongo_mocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def open(self):
7575
def request_check(self):
7676
pass
7777

78-
def close(self):
78+
async def close(self):
7979
self.opened = False
8080

8181

test/asynchronous/test_discovery_and_monitoring.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Test the topology module."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import os
1920
import socketserver
2021
import sys
@@ -23,7 +24,7 @@
2324
sys.path[0:0] = [""]
2425

2526
from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, unittest
26-
from test.pymongo_mocks import DummyMonitor
27+
from test.asynchronous.pymongo_mocks import DummyMonitor
2728
from test.unified_format import generate_test_classes
2829
from test.utils import (
2930
CMAPListener,
@@ -59,7 +60,15 @@
5960
_IS_SYNC = False
6061

6162
# Location of JSON test specifications.
62-
SDAM_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring")
63+
if _IS_SYNC:
64+
SDAM_PATH = os.path.join(
65+
os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring"
66+
)
67+
else:
68+
SDAM_PATH = os.path.join(
69+
os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir)),
70+
"discovery_and_monitoring",
71+
)
6372

6473

6574
async def create_mock_topology(uri, monitor_class=DummyMonitor):
@@ -92,7 +101,7 @@ async def got_hello(topology, server_address, hello_response):
92101
await topology.on_change(server_description)
93102

94103

95-
def got_app_error(topology, app_error):
104+
async def got_app_error(topology, app_error):
96105
server_address = common.partition_node(app_error["address"])
97106
server = topology.get_server_by_address(server_address)
98107
error_type = app_error["type"]
@@ -120,7 +129,7 @@ def got_app_error(topology, app_error):
120129
else:
121130
raise AssertionError(f"Unknown when field {when}")
122131

123-
topology.handle_error(
132+
await topology.handle_error(
124133
server_address,
125134
_ErrorContext(e, max_wire_version, generation, completed_handshake, None),
126135
)
@@ -207,12 +216,14 @@ async def run_scenario(self):
207216
for i, phase in enumerate(scenario_def["phases"]):
208217
# Including the phase description makes failures easier to debug.
209218
description = phase.get("description", str(i))
219+
if self._testMethodName == "test_single_direct_connection_external_ip":
220+
print("here")
210221
with assertion_context(f"phase: {description}"):
211222
for response in phase.get("responses", []):
212-
got_hello(c, common.partition_node(response[0]), response[1])
223+
await got_hello(c, common.partition_node(response[0]), response[1])
213224

214225
for app_error in phase.get("applicationErrors", []):
215-
got_app_error(c, app_error)
226+
await got_app_error(c, app_error)
216227

217228
check_outcome(self, c, phase["outcome"])
218229

@@ -369,7 +380,7 @@ def predicate():
369380
return False
370381
return True
371382

372-
wait_until(predicate, "find all RTT monitors")
383+
await async_wait_until(predicate, "find all RTT monitors")
373384

374385
async def test_rtt_connection_is_disabled_poll(self):
375386
client = await self.async_rs_or_single_client(serverMonitoringMode="poll")

test/test_discovery_and_monitoring.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Test the topology module."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import os
1920
import socketserver
2021
import sys
@@ -58,7 +59,15 @@
5859
_IS_SYNC = True
5960

6061
# Location of JSON test specifications.
61-
SDAM_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring")
62+
if _IS_SYNC:
63+
SDAM_PATH = os.path.join(
64+
os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring"
65+
)
66+
else:
67+
SDAM_PATH = os.path.join(
68+
os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir)),
69+
"discovery_and_monitoring",
70+
)
6271

6372

6473
def create_mock_topology(uri, monitor_class=DummyMonitor):
@@ -206,6 +215,8 @@ def run_scenario(self):
206215
for i, phase in enumerate(scenario_def["phases"]):
207216
# Including the phase description makes failures easier to debug.
208217
description = phase.get("description", str(i))
218+
if self._testMethodName == "test_single_direct_connection_external_ip":
219+
print("here")
209220
with assertion_context(f"phase: {description}"):
210221
for response in phase.get("responses", []):
211222
got_hello(c, common.partition_node(response[0]), response[1])

test/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,10 @@ async def async_wait_until(predicate, success_description, timeout=10):
751751
start = time.time()
752752
interval = min(float(timeout) / 100, 0.1)
753753
while True:
754-
retval = await predicate()
754+
if iscoroutinefunction(predicate):
755+
retval = await predicate()
756+
else:
757+
retval = predicate()
755758
if retval:
756759
return retval
757760

0 commit comments

Comments
 (0)