Skip to content

Commit 64aa048

Browse files
committed
create utils_selection_tests_shared.py
1 parent 7aa61e4 commit 64aa048

File tree

3 files changed

+112
-154
lines changed

3 files changed

+112
-154
lines changed

test/asynchronous/utils_selection_tests.py

Lines changed: 6 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -24,92 +24,21 @@
2424
from test import unittest
2525
from test.pymongo_mocks import DummyMonitor
2626
from test.utils import AsyncMockPool, parse_read_preference
27+
from test.utils_selection_tests_shared import (
28+
get_addresses,
29+
get_topology_type_name,
30+
make_server_description,
31+
)
2732

2833
from bson import json_util
2934
from pymongo.asynchronous.settings import TopologySettings
3035
from pymongo.asynchronous.topology import Topology
31-
from pymongo.common import HEARTBEAT_FREQUENCY, MIN_SUPPORTED_WIRE_VERSION, clean_node
36+
from pymongo.common import HEARTBEAT_FREQUENCY
3237
from pymongo.errors import AutoReconnect, ConfigurationError
33-
from pymongo.hello import Hello, HelloCompat
3438
from pymongo.operations import _Op
35-
from pymongo.server_description import ServerDescription
3639
from pymongo.server_selectors import writable_server_selector
3740

3841

39-
def get_addresses(server_list):
40-
seeds = []
41-
hosts = []
42-
for server in server_list:
43-
seeds.append(clean_node(server["address"]))
44-
hosts.append(server["address"])
45-
return seeds, hosts
46-
47-
48-
def make_last_write_date(server):
49-
epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None)
50-
millis = server.get("lastWrite", {}).get("lastWriteDate")
51-
if millis:
52-
diff = ((millis % 1000) + 1000) % 1000
53-
seconds = (millis - diff) / 1000
54-
micros = diff * 1000
55-
return epoch + datetime.timedelta(seconds=seconds, microseconds=micros)
56-
else:
57-
# "Unknown" server.
58-
return epoch
59-
60-
61-
def make_server_description(server, hosts):
62-
"""Make a ServerDescription from server info in a JSON test."""
63-
server_type = server["type"]
64-
if server_type in ("Unknown", "PossiblePrimary"):
65-
return ServerDescription(clean_node(server["address"]), Hello({}))
66-
67-
hello_response = {"ok": True, "hosts": hosts}
68-
if server_type not in ("Standalone", "Mongos", "RSGhost"):
69-
hello_response["setName"] = "rs"
70-
71-
if server_type == "RSPrimary":
72-
hello_response[HelloCompat.LEGACY_CMD] = True
73-
elif server_type == "RSSecondary":
74-
hello_response["secondary"] = True
75-
elif server_type == "Mongos":
76-
hello_response["msg"] = "isdbgrid"
77-
elif server_type == "RSGhost":
78-
hello_response["isreplicaset"] = True
79-
elif server_type == "RSArbiter":
80-
hello_response["arbiterOnly"] = True
81-
82-
hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)}
83-
84-
for field in "maxWireVersion", "tags", "idleWritePeriodMillis":
85-
if field in server:
86-
hello_response[field] = server[field]
87-
88-
hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION)
89-
90-
# Sets _last_update_time to now.
91-
sd = ServerDescription(
92-
clean_node(server["address"]),
93-
Hello(hello_response),
94-
round_trip_time=server["avg_rtt_ms"] / 1000.0,
95-
)
96-
97-
if "lastUpdateTime" in server:
98-
sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec.
99-
100-
return sd
101-
102-
103-
def get_topology_type_name(scenario_def):
104-
td = scenario_def["topology_description"]
105-
name = td["type"]
106-
if name == "Unknown":
107-
# PyMongo never starts a topology in type Unknown.
108-
return "Sharded" if len(td["servers"]) > 1 else "Single"
109-
else:
110-
return name
111-
112-
11342
def get_topology_settings_dict(**kwargs):
11443
settings = {
11544
"monitor_class": DummyMonitor,

test/utils_selection_tests.py

Lines changed: 6 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -24,92 +24,21 @@
2424
from test import unittest
2525
from test.pymongo_mocks import DummyMonitor
2626
from test.utils import MockPool, parse_read_preference
27+
from test.utils_selection_tests_shared import (
28+
get_addresses,
29+
get_topology_type_name,
30+
make_server_description,
31+
)
2732

2833
from bson import json_util
29-
from pymongo.common import HEARTBEAT_FREQUENCY, MIN_SUPPORTED_WIRE_VERSION, clean_node
34+
from pymongo.common import HEARTBEAT_FREQUENCY
3035
from pymongo.errors import AutoReconnect, ConfigurationError
31-
from pymongo.hello import Hello, HelloCompat
3236
from pymongo.operations import _Op
33-
from pymongo.server_description import ServerDescription
3437
from pymongo.server_selectors import writable_server_selector
3538
from pymongo.synchronous.settings import TopologySettings
3639
from pymongo.synchronous.topology import Topology
3740

3841

39-
def get_addresses(server_list):
40-
seeds = []
41-
hosts = []
42-
for server in server_list:
43-
seeds.append(clean_node(server["address"]))
44-
hosts.append(server["address"])
45-
return seeds, hosts
46-
47-
48-
def make_last_write_date(server):
49-
epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None)
50-
millis = server.get("lastWrite", {}).get("lastWriteDate")
51-
if millis:
52-
diff = ((millis % 1000) + 1000) % 1000
53-
seconds = (millis - diff) / 1000
54-
micros = diff * 1000
55-
return epoch + datetime.timedelta(seconds=seconds, microseconds=micros)
56-
else:
57-
# "Unknown" server.
58-
return epoch
59-
60-
61-
def make_server_description(server, hosts):
62-
"""Make a ServerDescription from server info in a JSON test."""
63-
server_type = server["type"]
64-
if server_type in ("Unknown", "PossiblePrimary"):
65-
return ServerDescription(clean_node(server["address"]), Hello({}))
66-
67-
hello_response = {"ok": True, "hosts": hosts}
68-
if server_type not in ("Standalone", "Mongos", "RSGhost"):
69-
hello_response["setName"] = "rs"
70-
71-
if server_type == "RSPrimary":
72-
hello_response[HelloCompat.LEGACY_CMD] = True
73-
elif server_type == "RSSecondary":
74-
hello_response["secondary"] = True
75-
elif server_type == "Mongos":
76-
hello_response["msg"] = "isdbgrid"
77-
elif server_type == "RSGhost":
78-
hello_response["isreplicaset"] = True
79-
elif server_type == "RSArbiter":
80-
hello_response["arbiterOnly"] = True
81-
82-
hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)}
83-
84-
for field in "maxWireVersion", "tags", "idleWritePeriodMillis":
85-
if field in server:
86-
hello_response[field] = server[field]
87-
88-
hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION)
89-
90-
# Sets _last_update_time to now.
91-
sd = ServerDescription(
92-
clean_node(server["address"]),
93-
Hello(hello_response),
94-
round_trip_time=server["avg_rtt_ms"] / 1000.0,
95-
)
96-
97-
if "lastUpdateTime" in server:
98-
sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec.
99-
100-
return sd
101-
102-
103-
def get_topology_type_name(scenario_def):
104-
td = scenario_def["topology_description"]
105-
name = td["type"]
106-
if name == "Unknown":
107-
# PyMongo never starts a topology in type Unknown.
108-
return "Sharded" if len(td["servers"]) > 1 else "Single"
109-
else:
110-
return name
111-
112-
11342
def get_topology_settings_dict(**kwargs):
11443
settings = {
11544
"monitor_class": DummyMonitor,

test/utils_selection_tests_shared.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2015-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utilities for testing Server Selection and Max Staleness."""
16+
from __future__ import annotations
17+
18+
import datetime
19+
import os
20+
import sys
21+
22+
sys.path[0:0] = [""]
23+
24+
from pymongo.common import MIN_SUPPORTED_WIRE_VERSION, clean_node
25+
from pymongo.hello import Hello, HelloCompat
26+
from pymongo.server_description import ServerDescription
27+
28+
29+
def get_addresses(server_list):
30+
seeds = []
31+
hosts = []
32+
for server in server_list:
33+
seeds.append(clean_node(server["address"]))
34+
hosts.append(server["address"])
35+
return seeds, hosts
36+
37+
38+
def make_last_write_date(server):
39+
epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None)
40+
millis = server.get("lastWrite", {}).get("lastWriteDate")
41+
if millis:
42+
diff = ((millis % 1000) + 1000) % 1000
43+
seconds = (millis - diff) / 1000
44+
micros = diff * 1000
45+
return epoch + datetime.timedelta(seconds=seconds, microseconds=micros)
46+
else:
47+
# "Unknown" server.
48+
return epoch
49+
50+
51+
def make_server_description(server, hosts):
52+
"""Make a ServerDescription from server info in a JSON test."""
53+
server_type = server["type"]
54+
if server_type in ("Unknown", "PossiblePrimary"):
55+
return ServerDescription(clean_node(server["address"]), Hello({}))
56+
57+
hello_response = {"ok": True, "hosts": hosts}
58+
if server_type not in ("Standalone", "Mongos", "RSGhost"):
59+
hello_response["setName"] = "rs"
60+
61+
if server_type == "RSPrimary":
62+
hello_response[HelloCompat.LEGACY_CMD] = True
63+
elif server_type == "RSSecondary":
64+
hello_response["secondary"] = True
65+
elif server_type == "Mongos":
66+
hello_response["msg"] = "isdbgrid"
67+
elif server_type == "RSGhost":
68+
hello_response["isreplicaset"] = True
69+
elif server_type == "RSArbiter":
70+
hello_response["arbiterOnly"] = True
71+
72+
hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)}
73+
74+
for field in "maxWireVersion", "tags", "idleWritePeriodMillis":
75+
if field in server:
76+
hello_response[field] = server[field]
77+
78+
hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION)
79+
80+
# Sets _last_update_time to now.
81+
sd = ServerDescription(
82+
clean_node(server["address"]),
83+
Hello(hello_response),
84+
round_trip_time=server["avg_rtt_ms"] / 1000.0,
85+
)
86+
87+
if "lastUpdateTime" in server:
88+
sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec.
89+
90+
return sd
91+
92+
93+
def get_topology_type_name(scenario_def):
94+
td = scenario_def["topology_description"]
95+
name = td["type"]
96+
if name == "Unknown":
97+
# PyMongo never starts a topology in type Unknown.
98+
return "Sharded" if len(td["servers"]) > 1 else "Single"
99+
else:
100+
return name

0 commit comments

Comments
 (0)