Skip to content

Commit ff0d5dc

Browse files
committed
Convert test.test_mongos_load_balancing to async
1 parent c8d3afd commit ff0d5dc

File tree

3 files changed

+268
-14
lines changed

3 files changed

+268
-14
lines changed
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
# Copyright 2015-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Test AsyncMongoClient's mongos load balancing using a mock."""
16+
from __future__ import annotations
17+
18+
import asyncio
19+
import sys
20+
import threading
21+
22+
from pymongo.operations import _Op
23+
24+
sys.path[0:0] = [""]
25+
26+
from test.asynchronous import AsyncMockClientTest, async_client_context, connected, unittest
27+
from test.asynchronous.pymongo_mocks import AsyncMockClient
28+
from test.utils import async_wait_until
29+
30+
from pymongo.errors import AutoReconnect, InvalidOperation
31+
from pymongo.server_selectors import writable_server_selector
32+
from pymongo.topology_description import TOPOLOGY_TYPE
33+
34+
_IS_SYNC = False
35+
36+
37+
@async_client_context.require_connection
38+
@async_client_context.require_no_load_balancer
39+
def asyncSetUpModule():
40+
pass
41+
42+
43+
if _IS_SYNC:
44+
45+
class SimpleOp(threading.Thread):
46+
def __init__(self, client):
47+
super().__init__()
48+
self.client = client
49+
self.passed = False
50+
51+
def run(self):
52+
self.client.db.command("ping")
53+
self.passed = True # No exception raised.
54+
else:
55+
56+
class SimpleOp:
57+
def __init__(self, client):
58+
self.task = asyncio.create_task(self.run())
59+
self.client = client
60+
self.passed = False
61+
62+
async def run(self):
63+
await self.client.db.command("ping")
64+
self.passed = True # No exception raised.
65+
66+
def start(self):
67+
pass
68+
69+
async def join(self):
70+
await self.task
71+
72+
73+
async def do_simple_op(client, nthreads):
74+
threads = [SimpleOp(client) for _ in range(nthreads)]
75+
for t in threads:
76+
t.start()
77+
78+
for t in threads:
79+
await t.join()
80+
81+
for t in threads:
82+
assert t.passed
83+
84+
85+
async def writable_addresses(topology):
86+
return {
87+
server.description.address
88+
for server in await topology.select_servers(writable_server_selector, _Op.TEST)
89+
}
90+
91+
92+
class TestMongosLoadBalancing(AsyncMockClientTest):
93+
def mock_client(self, **kwargs):
94+
mock_client = AsyncMockClient(
95+
standalones=[],
96+
members=[],
97+
mongoses=["a:1", "b:2", "c:3"],
98+
host="a:1,b:2,c:3",
99+
connect=False,
100+
**kwargs,
101+
)
102+
self.addAsyncCleanup(mock_client.aclose)
103+
104+
# Latencies in seconds.
105+
mock_client.mock_rtts["a:1"] = 0.020
106+
mock_client.mock_rtts["b:2"] = 0.025
107+
mock_client.mock_rtts["c:3"] = 0.045
108+
return mock_client
109+
110+
async def test_lazy_connect(self):
111+
# While connected() ensures we can trigger connection from the main
112+
# thread and wait for the monitors, this test triggers connection from
113+
# several threads at once to check for data races.
114+
nthreads = 10
115+
client = self.mock_client()
116+
self.assertEqual(0, len(client.nodes))
117+
118+
# Trigger initial connection.
119+
await do_simple_op(client, nthreads)
120+
await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses")
121+
122+
async def test_failover(self):
123+
nthreads = 10
124+
client = await connected(self.mock_client(localThresholdMS=0.001))
125+
await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses")
126+
127+
# Our chosen mongos goes down.
128+
client.kill_host("a:1")
129+
130+
# Trigger failover to higher-latency nodes. AutoReconnect should be
131+
# raised at most once in each thread.
132+
passed = []
133+
134+
async def f():
135+
try:
136+
await client.db.command("ping")
137+
except AutoReconnect:
138+
# Second attempt succeeds.
139+
await client.db.command("ping")
140+
141+
passed.append(True)
142+
143+
if _IS_SYNC:
144+
threads = [threading.Thread(target=f) for _ in range(nthreads)]
145+
for t in threads:
146+
t.start()
147+
148+
for t in threads:
149+
t.join()
150+
else:
151+
tasks = [asyncio.create_task(f()) for _ in range(nthreads)]
152+
for t in tasks:
153+
await t
154+
155+
self.assertEqual(nthreads, len(passed))
156+
157+
# Down host removed from list.
158+
self.assertEqual(2, len(client.nodes))
159+
160+
async def test_local_threshold(self):
161+
client = await connected(self.mock_client(localThresholdMS=30))
162+
self.assertEqual(30, client.options.local_threshold_ms)
163+
await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses")
164+
topology = client._topology
165+
166+
# All are within a 30-ms latency window, see self.mock_client().
167+
self.assertEqual({("a", 1), ("b", 2), ("c", 3)}, await writable_addresses(topology))
168+
169+
# No error
170+
await client.admin.command("ping")
171+
172+
client = await connected(self.mock_client(localThresholdMS=0))
173+
self.assertEqual(0, client.options.local_threshold_ms)
174+
# No error
175+
await client.db.command("ping")
176+
# Our chosen mongos goes down.
177+
client.kill_host("{}:{}".format(*next(iter(client.nodes))))
178+
try:
179+
await client.db.command("ping")
180+
except:
181+
pass
182+
183+
# We eventually connect to a new mongos.
184+
async def connect_to_new_mongos():
185+
try:
186+
return await client.db.command("ping")
187+
except AutoReconnect:
188+
pass
189+
190+
await async_wait_until(connect_to_new_mongos, "connect to a new mongos")
191+
192+
async def test_load_balancing(self):
193+
# Although the server selection JSON tests already prove that
194+
# select_servers works for sharded topologies, here we do an end-to-end
195+
# test of discovering servers' round trip times and configuring
196+
# localThresholdMS.
197+
client = await connected(self.mock_client())
198+
await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses")
199+
200+
# Prohibited for topology type Sharded.
201+
with self.assertRaises(InvalidOperation):
202+
await client.address
203+
204+
topology = client._topology
205+
self.assertEqual(TOPOLOGY_TYPE.Sharded, topology.description.topology_type)
206+
207+
# a and b are within the 15-ms latency window, see self.mock_client().
208+
self.assertEqual({("a", 1), ("b", 2)}, await writable_addresses(topology))
209+
210+
client.mock_rtts["a:1"] = 0.045
211+
212+
# Discover only b is within latency window.
213+
async def predicate():
214+
return {("b", 2)} == await writable_addresses(topology)
215+
216+
await async_wait_until(
217+
predicate,
218+
'discover server "a" is too far',
219+
)
220+
221+
222+
if __name__ == "__main__":
223+
unittest.main()

test/test_mongos_load_balancing.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Test MongoClient's mongos load balancing using a mock."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import sys
1920
import threading
2021

@@ -30,22 +31,43 @@
3031
from pymongo.server_selectors import writable_server_selector
3132
from pymongo.topology_description import TOPOLOGY_TYPE
3233

34+
_IS_SYNC = True
35+
3336

3437
@client_context.require_connection
3538
@client_context.require_no_load_balancer
3639
def setUpModule():
3740
pass
3841

3942

40-
class SimpleOp(threading.Thread):
41-
def __init__(self, client):
42-
super().__init__()
43-
self.client = client
44-
self.passed = False
43+
if _IS_SYNC:
44+
45+
class SimpleOp(threading.Thread):
46+
def __init__(self, client):
47+
super().__init__()
48+
self.client = client
49+
self.passed = False
50+
51+
def run(self):
52+
self.client.db.command("ping")
53+
self.passed = True # No exception raised.
54+
else:
55+
56+
class SimpleOp:
57+
def __init__(self, client):
58+
self.task = asyncio.create_task(self.run())
59+
self.client = client
60+
self.passed = False
4561

46-
def run(self):
47-
self.client.db.command("ping")
48-
self.passed = True # No exception raised.
62+
def run(self):
63+
self.client.db.command("ping")
64+
self.passed = True # No exception raised.
65+
66+
def start(self):
67+
pass
68+
69+
def join(self):
70+
self.task
4971

5072

5173
def do_simple_op(client, nthreads):
@@ -118,12 +140,17 @@ def f():
118140

119141
passed.append(True)
120142

121-
threads = [threading.Thread(target=f) for _ in range(nthreads)]
122-
for t in threads:
123-
t.start()
143+
if _IS_SYNC:
144+
threads = [threading.Thread(target=f) for _ in range(nthreads)]
145+
for t in threads:
146+
t.start()
124147

125-
for t in threads:
126-
t.join()
148+
for t in threads:
149+
t.join()
150+
else:
151+
tasks = [asyncio.create_task(f()) for _ in range(nthreads)]
152+
for t in tasks:
153+
t
127154

128155
self.assertEqual(nthreads, len(passed))
129156

@@ -183,8 +210,11 @@ def test_load_balancing(self):
183210
client.mock_rtts["a:1"] = 0.045
184211

185212
# Discover only b is within latency window.
213+
def predicate():
214+
return {("b", 2)} == writable_addresses(topology)
215+
186216
wait_until(
187-
lambda: {("b", 2)} == writable_addresses(topology),
217+
predicate,
188218
'discover server "a" is too far',
189219
)
190220

tools/synchro.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def async_only_test(f: str) -> bool:
215215
"test_gridfs_spec.py",
216216
"test_logger.py",
217217
"test_monitoring.py",
218+
"test_mongos_load_balancing.py",
218219
"test_raw_bson.py",
219220
"test_retryable_reads.py",
220221
"test_retryable_writes.py",

0 commit comments

Comments
 (0)