Skip to content

Commit f73e0d2

Browse files
committed
Convert test.test_server_selection_in_window to async
1 parent 6b141d1 commit f73e0d2

File tree

3 files changed

+234
-13
lines changed

3 files changed

+234
-13
lines changed
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
# Copyright 2020-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Test the topology module's Server Selection Spec implementation."""
16+
from __future__ import annotations
17+
18+
import asyncio
19+
import os
20+
import threading
21+
from pathlib import Path
22+
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
23+
from test.utils import (
24+
CMAPListener,
25+
OvertCommandListener,
26+
async_get_pool,
27+
async_wait_until,
28+
)
29+
from test.utils_selection_tests import create_topology
30+
from test.utils_spec_runner import SpecTestCreator
31+
32+
from pymongo.common import clean_node
33+
from pymongo.monitoring import ConnectionReadyEvent
34+
from pymongo.operations import _Op
35+
from pymongo.read_preferences import ReadPreference
36+
37+
_IS_SYNC = False
38+
# Location of JSON test specifications.
39+
if _IS_SYNC:
40+
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window")
41+
else:
42+
TEST_PATH = os.path.join(
43+
Path(__file__).resolve().parent.parent, "server_selection", "in_window"
44+
)
45+
46+
47+
class TestAllScenarios(unittest.IsolatedAsyncioTestCase):
48+
def run_scenario(self, scenario_def):
49+
topology = create_topology(scenario_def)
50+
51+
# Update mock operation_count state:
52+
for mock in scenario_def["mocked_topology_state"]:
53+
address = clean_node(mock["address"])
54+
server = topology.get_server_by_address(address)
55+
server.pool.operation_count = mock["operation_count"]
56+
57+
pref = ReadPreference.NEAREST
58+
counts = {address: 0 for address in topology._description.server_descriptions()}
59+
60+
# Number of times to repeat server selection
61+
iterations = scenario_def["iterations"]
62+
for _ in range(iterations):
63+
server = topology.select_server(pref, _Op.TEST, server_selection_timeout=0)
64+
counts[server.description.address] += 1
65+
66+
# Verify expected_frequencies
67+
outcome = scenario_def["outcome"]
68+
tolerance = outcome["tolerance"]
69+
expected_frequencies = outcome["expected_frequencies"]
70+
for host_str, freq in expected_frequencies.items():
71+
address = clean_node(host_str)
72+
actual_freq = float(counts[address]) / iterations
73+
if freq == 0:
74+
# Should be exactly 0.
75+
self.assertEqual(actual_freq, 0)
76+
else:
77+
# Should be within 'tolerance'.
78+
self.assertAlmostEqual(actual_freq, freq, delta=tolerance)
79+
80+
81+
def create_test(scenario_def, test, name):
82+
def run_scenario(self):
83+
self.run_scenario(scenario_def)
84+
85+
return run_scenario
86+
87+
88+
class CustomSpecTestCreator(SpecTestCreator):
89+
def tests(self, scenario_def):
90+
"""Extract the tests from a spec file.
91+
92+
Server selection in_window tests do not have a 'tests' field.
93+
The whole file represents a single test case.
94+
"""
95+
return [scenario_def]
96+
97+
98+
CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests()
99+
100+
if _IS_SYNC:
101+
PARENT = threading.Thread
102+
else:
103+
PARENT = object
104+
105+
106+
class FinderThread(PARENT):
107+
def __init__(self, collection, iterations):
108+
super().__init__()
109+
self.daemon = True
110+
self.collection = collection
111+
self.iterations = iterations
112+
self.passed = False
113+
self.task = None
114+
115+
async def run(self):
116+
for _ in range(self.iterations):
117+
await self.collection.find_one({})
118+
self.passed = True
119+
120+
def start(self):
121+
if _IS_SYNC:
122+
super().start()
123+
else:
124+
self.task = asyncio.create_task(self.run())
125+
126+
async def join(self):
127+
if _IS_SYNC:
128+
super().join()
129+
else:
130+
await self.task
131+
132+
133+
class TestProse(AsyncIntegrationTest):
134+
def frequencies(self, client, listener, n_finds=10):
135+
coll = client.test.test
136+
N_TASKS = 10
137+
tasks = [FinderThread(coll, n_finds) for _ in range(N_TASKS)]
138+
for task in tasks:
139+
task.start()
140+
for task in tasks:
141+
task.join()
142+
for task in tasks:
143+
self.assertTrue(task.passed)
144+
145+
events = listener.started_events
146+
self.assertEqual(len(events), n_finds * N_TASKS)
147+
nodes = client.nodes
148+
self.assertEqual(len(nodes), 2)
149+
freqs = {address: 0.0 for address in nodes}
150+
for event in events:
151+
freqs[event.connection_id] += 1
152+
for address in freqs:
153+
freqs[address] = freqs[address] / float(len(events))
154+
return freqs
155+
156+
@async_client_context.require_failCommand_appName
157+
@async_client_context.require_multiple_mongoses
158+
async def test_load_balancing(self):
159+
listener = OvertCommandListener()
160+
cmap_listener = CMAPListener()
161+
# PYTHON-2584: Use a large localThresholdMS to avoid the impact of
162+
# varying RTTs.
163+
client = await self.async_rs_client(
164+
async_client_context.mongos_seeds(),
165+
appName="loadBalancingTest",
166+
event_listeners=[listener, cmap_listener],
167+
localThresholdMS=30000,
168+
minPoolSize=10,
169+
)
170+
await async_wait_until(lambda: len(client.nodes) == 2, "discover both nodes")
171+
# Wait for both pools to be populated.
172+
cmap_listener.wait_for_event(ConnectionReadyEvent, 20)
173+
# Delay find commands on only one mongos.
174+
delay_finds = {
175+
"configureFailPoint": "failCommand",
176+
"mode": {"times": 10000},
177+
"data": {
178+
"failCommands": ["find"],
179+
"blockAsyncConnection": True,
180+
"blockTimeMS": 500,
181+
"appName": "loadBalancingTest",
182+
},
183+
}
184+
with self.fail_point(delay_finds):
185+
nodes = async_client_context.client.nodes
186+
self.assertEqual(len(nodes), 1)
187+
delayed_server = await anext(iter(nodes))
188+
freqs = self.frequencies(client, listener)
189+
self.assertLessEqual(freqs[delayed_server], 0.25)
190+
listener.reset()
191+
freqs = self.frequencies(client, listener, n_finds=150)
192+
self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15)
193+
194+
195+
if __name__ == "__main__":
196+
unittest.main()

test/test_server_selection_in_window.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
"""Test the topology module's Server Selection Spec implementation."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import os
1920
import threading
21+
from pathlib import Path
2022
from test import IntegrationTest, client_context, unittest
2123
from test.utils import (
2224
CMAPListener,
@@ -32,10 +34,14 @@
3234
from pymongo.operations import _Op
3335
from pymongo.read_preferences import ReadPreference
3436

37+
_IS_SYNC = True
3538
# Location of JSON test specifications.
36-
TEST_PATH = os.path.join(
37-
os.path.dirname(os.path.realpath(__file__)), os.path.join("server_selection", "in_window")
38-
)
39+
if _IS_SYNC:
40+
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window")
41+
else:
42+
TEST_PATH = os.path.join(
43+
Path(__file__).resolve().parent.parent, "server_selection", "in_window"
44+
)
3945

4046

4147
class TestAllScenarios(unittest.TestCase):
@@ -91,35 +97,53 @@ def tests(self, scenario_def):
9197

9298
CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests()
9399

100+
if _IS_SYNC:
101+
PARENT = threading.Thread
102+
else:
103+
PARENT = object
104+
94105

95-
class FinderThread(threading.Thread):
106+
class FinderThread(PARENT):
96107
def __init__(self, collection, iterations):
97108
super().__init__()
98109
self.daemon = True
99110
self.collection = collection
100111
self.iterations = iterations
101112
self.passed = False
113+
self.task = None
102114

103115
def run(self):
104116
for _ in range(self.iterations):
105117
self.collection.find_one({})
106118
self.passed = True
107119

120+
def start(self):
121+
if _IS_SYNC:
122+
super().start()
123+
else:
124+
self.task = asyncio.create_task(self.run())
125+
126+
def join(self):
127+
if _IS_SYNC:
128+
super().join()
129+
else:
130+
self.task
131+
108132

109133
class TestProse(IntegrationTest):
110134
def frequencies(self, client, listener, n_finds=10):
111135
coll = client.test.test
112-
N_THREADS = 10
113-
threads = [FinderThread(coll, n_finds) for _ in range(N_THREADS)]
114-
for thread in threads:
115-
thread.start()
116-
for thread in threads:
117-
thread.join()
118-
for thread in threads:
119-
self.assertTrue(thread.passed)
136+
N_TASKS = 10
137+
tasks = [FinderThread(coll, n_finds) for _ in range(N_TASKS)]
138+
for task in tasks:
139+
task.start()
140+
for task in tasks:
141+
task.join()
142+
for task in tasks:
143+
self.assertTrue(task.passed)
120144

121145
events = listener.started_events
122-
self.assertEqual(len(events), n_finds * N_THREADS)
146+
self.assertEqual(len(events), n_finds * N_TASKS)
123147
nodes = client.nodes
124148
self.assertEqual(len(nodes), 2)
125149
freqs = {address: 0.0 for address in nodes}

tools/synchro.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def async_only_test(f: str) -> bool:
224224
"test_retryable_reads_unified.py",
225225
"test_retryable_writes.py",
226226
"test_retryable_writes_unified.py",
227+
"test_server_selection_in_window.py",
227228
"test_session.py",
228229
"test_transactions.py",
229230
"unified_format.py",

0 commit comments

Comments
 (0)