Skip to content

Commit 7aa61e4

Browse files
committed
create async utils_selection_tests
1 parent 5970b55 commit 7aa61e4

File tree

3 files changed

+281
-10
lines changed

3 files changed

+281
-10
lines changed

test/asynchronous/test_server_selection_in_window.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
from pathlib import Path
2222
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
2323
from test.asynchronous.helpers import ConcurrentRunner
24+
from test.asynchronous.utils_selection_tests import create_topology
25+
from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator
2426
from test.utils import (
2527
CMAPListener,
2628
OvertCommandListener,
2729
async_get_pool,
2830
async_wait_until,
2931
)
30-
from test.utils_selection_tests import create_topology
31-
from test.utils_spec_runner import SpecTestCreator
3232

3333
from pymongo.common import clean_node
3434
from pymongo.monitoring import ConnectionReadyEvent
@@ -46,8 +46,8 @@
4646

4747

4848
class TestAllScenarios(unittest.IsolatedAsyncioTestCase):
49-
def run_scenario(self, scenario_def):
50-
topology = create_topology(scenario_def)
49+
async def run_scenario(self, scenario_def):
50+
topology = await create_topology(scenario_def)
5151

5252
# Update mock operation_count state:
5353
for mock in scenario_def["mocked_topology_state"]:
@@ -61,7 +61,7 @@ def run_scenario(self, scenario_def):
6161
# Number of times to repeat server selection
6262
iterations = scenario_def["iterations"]
6363
for _ in range(iterations):
64-
server = topology.select_server(pref, _Op.TEST, server_selection_timeout=0)
64+
server = await topology.select_server(pref, _Op.TEST, server_selection_timeout=0)
6565
counts[server.description.address] += 1
6666

6767
# Verify expected_frequencies
@@ -80,13 +80,13 @@ def run_scenario(self, scenario_def):
8080

8181

8282
def create_test(scenario_def, test, name):
83-
def run_scenario(self):
84-
self.run_scenario(scenario_def)
83+
async def run_scenario(self):
84+
await self.run_scenario(scenario_def)
8585

8686
return run_scenario
8787

8888

89-
class CustomSpecTestCreator(SpecTestCreator):
89+
class CustomSpecTestCreator(AsyncSpecTestCreator):
9090
def tests(self, scenario_def):
9191
"""Extract the tests from a spec file.
9292
@@ -164,7 +164,7 @@ async def test_load_balancing(self):
164164
"appName": "loadBalancingTest",
165165
},
166166
}
167-
with self.fail_point(delay_finds):
167+
async with self.fail_point(delay_finds):
168168
nodes = async_client_context.client.nodes
169169
self.assertEqual(len(nodes), 1)
170170
delayed_server = next(iter(nodes))
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
# Copyright 2015-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utilities for testing Server Selection and Max Staleness."""
16+
from __future__ import annotations
17+
18+
import datetime
19+
import os
20+
import sys
21+
22+
sys.path[0:0] = [""]
23+
24+
from test import unittest
25+
from test.pymongo_mocks import DummyMonitor
26+
from test.utils import AsyncMockPool, parse_read_preference
27+
28+
from bson import json_util
29+
from pymongo.asynchronous.settings import TopologySettings
30+
from pymongo.asynchronous.topology import Topology
31+
from pymongo.common import HEARTBEAT_FREQUENCY, MIN_SUPPORTED_WIRE_VERSION, clean_node
32+
from pymongo.errors import AutoReconnect, ConfigurationError
33+
from pymongo.hello import Hello, HelloCompat
34+
from pymongo.operations import _Op
35+
from pymongo.server_description import ServerDescription
36+
from pymongo.server_selectors import writable_server_selector
37+
38+
39+
def get_addresses(server_list):
40+
seeds = []
41+
hosts = []
42+
for server in server_list:
43+
seeds.append(clean_node(server["address"]))
44+
hosts.append(server["address"])
45+
return seeds, hosts
46+
47+
48+
def make_last_write_date(server):
49+
epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None)
50+
millis = server.get("lastWrite", {}).get("lastWriteDate")
51+
if millis:
52+
diff = ((millis % 1000) + 1000) % 1000
53+
seconds = (millis - diff) / 1000
54+
micros = diff * 1000
55+
return epoch + datetime.timedelta(seconds=seconds, microseconds=micros)
56+
else:
57+
# "Unknown" server.
58+
return epoch
59+
60+
61+
def make_server_description(server, hosts):
62+
"""Make a ServerDescription from server info in a JSON test."""
63+
server_type = server["type"]
64+
if server_type in ("Unknown", "PossiblePrimary"):
65+
return ServerDescription(clean_node(server["address"]), Hello({}))
66+
67+
hello_response = {"ok": True, "hosts": hosts}
68+
if server_type not in ("Standalone", "Mongos", "RSGhost"):
69+
hello_response["setName"] = "rs"
70+
71+
if server_type == "RSPrimary":
72+
hello_response[HelloCompat.LEGACY_CMD] = True
73+
elif server_type == "RSSecondary":
74+
hello_response["secondary"] = True
75+
elif server_type == "Mongos":
76+
hello_response["msg"] = "isdbgrid"
77+
elif server_type == "RSGhost":
78+
hello_response["isreplicaset"] = True
79+
elif server_type == "RSArbiter":
80+
hello_response["arbiterOnly"] = True
81+
82+
hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)}
83+
84+
for field in "maxWireVersion", "tags", "idleWritePeriodMillis":
85+
if field in server:
86+
hello_response[field] = server[field]
87+
88+
hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION)
89+
90+
# Sets _last_update_time to now.
91+
sd = ServerDescription(
92+
clean_node(server["address"]),
93+
Hello(hello_response),
94+
round_trip_time=server["avg_rtt_ms"] / 1000.0,
95+
)
96+
97+
if "lastUpdateTime" in server:
98+
sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec.
99+
100+
return sd
101+
102+
103+
def get_topology_type_name(scenario_def):
104+
td = scenario_def["topology_description"]
105+
name = td["type"]
106+
if name == "Unknown":
107+
# PyMongo never starts a topology in type Unknown.
108+
return "Sharded" if len(td["servers"]) > 1 else "Single"
109+
else:
110+
return name
111+
112+
113+
def get_topology_settings_dict(**kwargs):
114+
settings = {
115+
"monitor_class": DummyMonitor,
116+
"heartbeat_frequency": HEARTBEAT_FREQUENCY,
117+
"pool_class": AsyncMockPool,
118+
}
119+
settings.update(kwargs)
120+
return settings
121+
122+
123+
async def create_topology(scenario_def, **kwargs):
124+
# Initialize topologies.
125+
if "heartbeatFrequencyMS" in scenario_def:
126+
frequency = int(scenario_def["heartbeatFrequencyMS"]) / 1000.0
127+
else:
128+
frequency = HEARTBEAT_FREQUENCY
129+
130+
seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"])
131+
132+
topology_type = get_topology_type_name(scenario_def)
133+
if topology_type == "LoadBalanced":
134+
kwargs.setdefault("load_balanced", True)
135+
# Force topology description to ReplicaSet
136+
elif topology_type in ["ReplicaSetNoPrimary", "ReplicaSetWithPrimary"]:
137+
kwargs.setdefault("replica_set_name", "rs")
138+
settings = get_topology_settings_dict(heartbeat_frequency=frequency, seeds=seeds, **kwargs)
139+
140+
# "Eligible servers" is defined in the server selection spec as
141+
# the set of servers matching both the ReadPreference's mode
142+
# and tag sets.
143+
topology = Topology(TopologySettings(**settings))
144+
await topology.open()
145+
146+
# Update topologies with server descriptions.
147+
for server in scenario_def["topology_description"]["servers"]:
148+
server_description = make_server_description(server, hosts)
149+
await topology.on_change(server_description)
150+
151+
# Assert that descriptions match
152+
assert (
153+
scenario_def["topology_description"]["type"] == topology.description.topology_type_name
154+
), topology.description.topology_type_name
155+
156+
return topology
157+
158+
159+
def create_test(scenario_def):
160+
async def run_scenario(self):
161+
_, hosts = get_addresses(scenario_def["topology_description"]["servers"])
162+
# "Eligible servers" is defined in the server selection spec as
163+
# the set of servers matching both the ReadPreference's mode
164+
# and tag sets.
165+
top_latency = await create_topology(scenario_def)
166+
167+
# "In latency window" is defined in the server selection
168+
# spec as the subset of suitable_servers that falls within the
169+
# allowable latency window.
170+
top_suitable = await create_topology(scenario_def, local_threshold_ms=1000000)
171+
172+
# Create server selector.
173+
if scenario_def.get("operation") == "write":
174+
pref = writable_server_selector
175+
else:
176+
# Make first letter lowercase to match read_pref's modes.
177+
pref_def = scenario_def["read_preference"]
178+
if scenario_def.get("error"):
179+
with self.assertRaises((ConfigurationError, ValueError)):
180+
# Error can be raised when making Read Pref or selecting.
181+
pref = parse_read_preference(pref_def)
182+
await top_latency.select_server(pref, _Op.TEST)
183+
return
184+
185+
pref = parse_read_preference(pref_def)
186+
187+
# Select servers.
188+
if not scenario_def.get("suitable_servers"):
189+
with self.assertRaises(AutoReconnect):
190+
await top_suitable.select_server(pref, _Op.TEST, server_selection_timeout=0)
191+
192+
return
193+
194+
if not scenario_def["in_latency_window"]:
195+
with self.assertRaises(AutoReconnect):
196+
await top_latency.select_server(pref, _Op.TEST, server_selection_timeout=0)
197+
198+
return
199+
200+
actual_suitable_s = await top_suitable.select_servers(
201+
pref, _Op.TEST, server_selection_timeout=0
202+
)
203+
actual_latency_s = await top_latency.select_servers(
204+
pref, _Op.TEST, server_selection_timeout=0
205+
)
206+
207+
expected_suitable_servers = {}
208+
for server in scenario_def["suitable_servers"]:
209+
server_description = make_server_description(server, hosts)
210+
expected_suitable_servers[server["address"]] = server_description
211+
212+
actual_suitable_servers = {}
213+
for s in actual_suitable_s:
214+
actual_suitable_servers[
215+
"%s:%d" % (s.description.address[0], s.description.address[1])
216+
] = s.description
217+
218+
self.assertEqual(len(actual_suitable_servers), len(expected_suitable_servers))
219+
for k, actual in actual_suitable_servers.items():
220+
expected = expected_suitable_servers[k]
221+
self.assertEqual(expected.address, actual.address)
222+
self.assertEqual(expected.server_type, actual.server_type)
223+
self.assertEqual(expected.round_trip_time, actual.round_trip_time)
224+
self.assertEqual(expected.tags, actual.tags)
225+
self.assertEqual(expected.all_hosts, actual.all_hosts)
226+
227+
expected_latency_servers = {}
228+
for server in scenario_def["in_latency_window"]:
229+
server_description = make_server_description(server, hosts)
230+
expected_latency_servers[server["address"]] = server_description
231+
232+
actual_latency_servers = {}
233+
for s in actual_latency_s:
234+
actual_latency_servers[
235+
"%s:%d" % (s.description.address[0], s.description.address[1])
236+
] = s.description
237+
238+
self.assertEqual(len(actual_latency_servers), len(expected_latency_servers))
239+
for k, actual in actual_latency_servers.items():
240+
expected = expected_latency_servers[k]
241+
self.assertEqual(expected.address, actual.address)
242+
self.assertEqual(expected.server_type, actual.server_type)
243+
self.assertEqual(expected.round_trip_time, actual.round_trip_time)
244+
self.assertEqual(expected.tags, actual.tags)
245+
self.assertEqual(expected.all_hosts, actual.all_hosts)
246+
247+
return run_scenario
248+
249+
250+
def create_selection_tests(test_dir):
251+
class TestAllScenarios(unittest.TestCase):
252+
pass
253+
254+
for dirpath, _, filenames in os.walk(test_dir):
255+
dirname = os.path.split(dirpath)
256+
dirname = os.path.split(dirname[-2])[-1] + "_" + dirname[-1]
257+
258+
for filename in filenames:
259+
if os.path.splitext(filename)[1] != ".json":
260+
continue
261+
with open(os.path.join(dirpath, filename)) as scenario_stream:
262+
scenario_def = json_util.loads(scenario_stream.read())
263+
264+
# Construct test from scenario.
265+
new_test = create_test(scenario_def)
266+
test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}"
267+
268+
new_test.__name__ = test_name
269+
setattr(TestAllScenarios, new_test.__name__, new_test)
270+
271+
return TestAllScenarios

test/asynchronous/utils_spec_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ async def _create_tests(self):
229229
str(test_def["description"].replace(" ", "_").replace(".", "_")),
230230
)
231231

232-
new_test = await self._create_test(scenario_def, test_def, test_name)
232+
new_test = self._create_test(scenario_def, test_def, test_name)
233233
new_test = self._ensure_min_max_server_version(scenario_def, new_test)
234234
new_test = self.ensure_run_on(scenario_def, new_test)
235235

0 commit comments

Comments
 (0)