Skip to content

Commit 8c2b368

Browse files
committed
respond to comments and move srv_resolver to async
1 parent bc61199 commit 8c2b368

File tree

11 files changed

+228
-50
lines changed

11 files changed

+228
-50
lines changed

pymongo/asynchronous/mongo_client.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,10 +1216,12 @@ def options(self) -> ClientOptions:
12161216

12171217
def __eq__(self, other: Any) -> bool:
12181218
if isinstance(other, self.__class__):
1219-
if hasattr(self, "_topology"):
1219+
if hasattr(self, "_topology") and hasattr(other, "_topology"):
12201220
return self._topology == other._topology
12211221
else:
1222-
raise InvalidOperation("Cannot compare client equality until both clients are connected")
1222+
raise InvalidOperation(
1223+
"Cannot compare client equality until both clients are connected"
1224+
)
12231225
return NotImplemented
12241226

12251227
def __ne__(self, other: Any) -> bool:
@@ -1245,13 +1247,16 @@ def option_repr(option: str, value: Any) -> str:
12451247
return f"{option}={value!r}"
12461248

12471249
# Host first...
1248-
options = [
1249-
"host=%r"
1250-
% [
1251-
"%s:%d" % (host, port) if port is not None else host
1252-
for host, port in self._topology_settings.seeds
1250+
if hasattr(self, "_topology"):
1251+
options = [
1252+
"host=%r"
1253+
% [
1254+
"%s:%d" % (host, port) if port is not None else host
1255+
for host, port in self._topology_settings.seeds
1256+
]
12531257
]
1254-
]
1258+
else:
1259+
options = []
12551260
# ... then everything in self._constructor_args...
12561261
options.extend(
12571262
option_repr(key, self._options._options[key]) for key in self._constructor_args
@@ -1265,9 +1270,7 @@ def option_repr(option: str, value: Any) -> str:
12651270
return ", ".join(options)
12661271

12671272
def __repr__(self) -> str:
1268-
if hasattr(self, "_topology"):
1269-
return f"{type(self).__name__}({self._repr_helper()})"
1270-
raise InvalidOperation("Cannot perform operation until client is connected")
1273+
return f"{type(self).__name__}({self._repr_helper()})"
12711274

12721275
def __getattr__(self, name: str) -> database.AsyncDatabase[_DocumentType]:
12731276
"""Get a database by name.

pymongo/asynchronous/monitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from pymongo import common, periodic_executor
2727
from pymongo._csot import MovingMinimum
28+
from pymongo.asynchronous.srv_resolver import _SrvResolver
2829
from pymongo.errors import NetworkTimeout, _OperationCancelled
2930
from pymongo.hello import Hello
3031
from pymongo.lock import _async_create_lock
@@ -33,7 +34,6 @@
3334
from pymongo.pool_options import _is_faas
3435
from pymongo.read_preferences import MovingAverage
3536
from pymongo.server_description import ServerDescription
36-
from pymongo.srv_resolver import _SrvResolver
3737

3838
if TYPE_CHECKING:
3939
from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext

pymongo/asynchronous/srv_resolver.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Copyright 2019-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you
4+
# may not use this file except in compliance with the License. You
5+
# may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12+
# implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
14+
15+
"""Support for resolving hosts and options from mongodb+srv:// URIs."""
16+
from __future__ import annotations
17+
18+
import ipaddress
19+
import random
20+
from typing import TYPE_CHECKING, Any, Optional, Union
21+
22+
from pymongo.common import CONNECT_TIMEOUT
23+
from pymongo.errors import ConfigurationError
24+
25+
if TYPE_CHECKING:
26+
from dns import resolver
27+
28+
_IS_SYNC = False
29+
30+
31+
def _have_dnspython() -> bool:
32+
try:
33+
import dns # noqa: F401
34+
35+
return True
36+
except ImportError:
37+
return False
38+
39+
40+
# dnspython can return bytes or str from various parts
41+
# of its API depending on version. We always want str.
42+
def maybe_decode(text: Union[str, bytes]) -> str:
43+
if isinstance(text, bytes):
44+
return text.decode()
45+
return text
46+
47+
48+
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
49+
async def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
50+
if _IS_SYNC:
51+
from dns import resolver
52+
53+
if hasattr(resolver, "resolve"):
54+
# dnspython >= 2
55+
return resolver.resolve(*args, **kwargs)
56+
# dnspython 1.X
57+
return resolver.query(*args, **kwargs)
58+
else:
59+
from dns.asyncresolver import Resolver
60+
61+
return await Resolver.resolve(*args, **kwargs)
62+
63+
64+
_INVALID_HOST_MSG = (
65+
"Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. "
66+
"Did you mean to use 'mongodb://'?"
67+
)
68+
69+
70+
class _SrvResolver:
71+
def __init__(
72+
self,
73+
fqdn: str,
74+
connect_timeout: Optional[float],
75+
srv_service_name: str,
76+
srv_max_hosts: int = 0,
77+
):
78+
self.__fqdn = fqdn
79+
self.__srv = srv_service_name
80+
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
81+
self.__srv_max_hosts = srv_max_hosts or 0
82+
# Validate the fully qualified domain name.
83+
try:
84+
ipaddress.ip_address(fqdn)
85+
raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))
86+
except ValueError:
87+
pass
88+
89+
try:
90+
self.__plist = self.__fqdn.split(".")[1:]
91+
except Exception:
92+
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
93+
self.__slen = len(self.__plist)
94+
if self.__slen < 2:
95+
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
96+
97+
async def get_options(self) -> Optional[str]:
98+
from dns import resolver
99+
100+
try:
101+
results = await _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout)
102+
except (resolver.NoAnswer, resolver.NXDOMAIN):
103+
# No TXT records
104+
return None
105+
except Exception as exc:
106+
raise ConfigurationError(str(exc)) from None
107+
if len(results) > 1:
108+
raise ConfigurationError("Only one TXT record is supported")
109+
return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8") # type: ignore[attr-defined]
110+
111+
async def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer:
112+
try:
113+
results = await _resolve(
114+
"_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout
115+
)
116+
except Exception as exc:
117+
if not encapsulate_errors:
118+
# Raise the original error.
119+
raise
120+
# Else, raise all errors as ConfigurationError.
121+
raise ConfigurationError(str(exc)) from None
122+
return results
123+
124+
async def _get_srv_response_and_hosts(
125+
self, encapsulate_errors: bool
126+
) -> tuple[resolver.Answer, list[tuple[str, Any]]]:
127+
results = await self._resolve_uri(encapsulate_errors)
128+
129+
# Construct address tuples
130+
nodes = [
131+
(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) # type: ignore[attr-defined]
132+
for res in results
133+
]
134+
135+
# Validate hosts
136+
for node in nodes:
137+
try:
138+
nlist = node[0].lower().split(".")[1:][-self.__slen :]
139+
except Exception:
140+
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None
141+
if self.__plist != nlist:
142+
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
143+
if self.__srv_max_hosts:
144+
nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes)))
145+
return results, nodes
146+
147+
async def get_hosts(self) -> list[tuple[str, Any]]:
148+
_, nodes = await self._get_srv_response_and_hosts(True)
149+
return nodes
150+
151+
async def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]:
152+
results, nodes = await self._get_srv_response_and_hosts(False)
153+
rrset = results.rrset
154+
ttl = rrset.ttl if rrset else 0
155+
return nodes, ttl

pymongo/synchronous/mongo_client.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,10 +1214,12 @@ def options(self) -> ClientOptions:
12141214

12151215
def __eq__(self, other: Any) -> bool:
12161216
if isinstance(other, self.__class__):
1217-
if hasattr(self, "_topology"):
1217+
if hasattr(self, "_topology") and hasattr(other, "_topology"):
12181218
return self._topology == other._topology
12191219
else:
1220-
raise InvalidOperation("Cannot perform operation until client is connected")
1220+
raise InvalidOperation(
1221+
"Cannot compare client equality until both clients are connected"
1222+
)
12211223
return NotImplemented
12221224

12231225
def __ne__(self, other: Any) -> bool:
@@ -1227,7 +1229,7 @@ def __hash__(self) -> int:
12271229
if hasattr(self, "_topology"):
12281230
return hash(self._topology)
12291231
else:
1230-
raise InvalidOperation("Cannot perform operation until client is connected")
1232+
raise InvalidOperation("Cannot hash client until it is connected")
12311233

12321234
def _repr_helper(self) -> str:
12331235
def option_repr(option: str, value: Any) -> str:
@@ -1243,13 +1245,16 @@ def option_repr(option: str, value: Any) -> str:
12431245
return f"{option}={value!r}"
12441246

12451247
# Host first...
1246-
options = [
1247-
"host=%r"
1248-
% [
1249-
"%s:%d" % (host, port) if port is not None else host
1250-
for host, port in self._topology_settings.seeds
1248+
if hasattr(self, "_topology"):
1249+
options = [
1250+
"host=%r"
1251+
% [
1252+
"%s:%d" % (host, port) if port is not None else host
1253+
for host, port in self._topology_settings.seeds
1254+
]
12511255
]
1252-
]
1256+
else:
1257+
options = []
12531258
# ... then everything in self._constructor_args...
12541259
options.extend(
12551260
option_repr(key, self._options._options[key]) for key in self._constructor_args
@@ -1263,9 +1268,7 @@ def option_repr(option: str, value: Any) -> str:
12631268
return ", ".join(options)
12641269

12651270
def __repr__(self) -> str:
1266-
if hasattr(self, "_topology"):
1267-
return f"{type(self).__name__}({self._repr_helper()})"
1268-
raise InvalidOperation("Cannot perform operation until client is connected")
1271+
return f"{type(self).__name__}({self._repr_helper()})"
12691272

12701273
def __getattr__(self, name: str) -> database.Database[_DocumentType]:
12711274
"""Get a database by name.

pymongo/synchronous/monitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from pymongo.pool_options import _is_faas
3434
from pymongo.read_preferences import MovingAverage
3535
from pymongo.server_description import ServerDescription
36-
from pymongo.srv_resolver import _SrvResolver
36+
from pymongo.synchronous.srv_resolver import _SrvResolver
3737

3838
if TYPE_CHECKING:
3939
from pymongo.synchronous.pool import Connection, Pool, _CancellationContext

pymongo/srv_resolver.py renamed to pymongo/synchronous/srv_resolver.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
if TYPE_CHECKING:
2626
from dns import resolver
2727

28+
_IS_SYNC = True
29+
2830

2931
def _have_dnspython() -> bool:
3032
try:
@@ -45,13 +47,18 @@ def maybe_decode(text: Union[str, bytes]) -> str:
4547

4648
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
4749
def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
48-
from dns import resolver
50+
if _IS_SYNC:
51+
from dns import resolver
52+
53+
if hasattr(resolver, "resolve"):
54+
# dnspython >= 2
55+
return resolver.resolve(*args, **kwargs)
56+
# dnspython 1.X
57+
return resolver.query(*args, **kwargs)
58+
else:
59+
from dns.asyncresolver import Resolver
4960

50-
if hasattr(resolver, "resolve"):
51-
# dnspython >= 2
52-
return resolver.resolve(*args, **kwargs)
53-
# dnspython 1.X
54-
return resolver.query(*args, **kwargs)
61+
return Resolver.resolve(*args, **kwargs)
5562

5663

5764
_INVALID_HOST_MSG = (

pymongo/uri_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
)
3535
from urllib.parse import unquote_plus
3636

37+
from pymongo.asynchronous.srv_resolver import _have_dnspython, _SrvResolver
3738
from pymongo.client_options import _parse_ssl_options
3839
from pymongo.common import (
3940
INTERNAL_URI_OPTION_NAME_MAP,
@@ -43,7 +44,6 @@
4344
get_validated_options,
4445
)
4546
from pymongo.errors import ConfigurationError, InvalidURI
46-
from pymongo.srv_resolver import _have_dnspython, _SrvResolver
4747
from pymongo.typings import _Address
4848

4949
if TYPE_CHECKING:

test/asynchronous/test_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -512,13 +512,13 @@ async def test_uri_option_precedence(self):
512512

513513
async def test_connection_timeout_ms_propagates_to_DNS_resolver(self):
514514
# Patch the resolver.
515-
from pymongo.srv_resolver import _resolve
515+
from pymongo.asynchronous.srv_resolver import _resolve
516516

517517
patched_resolver = FunctionCallRecorder(_resolve)
518-
pymongo.srv_resolver._resolve = patched_resolver
518+
pymongo.asynchronous.srv_resolver._resolve = patched_resolver
519519

520520
def reset_resolver():
521-
pymongo.srv_resolver._resolve = _resolve
521+
pymongo.asynchronous.srv_resolver._resolve = _resolve
522522

523523
self.addCleanup(reset_resolver)
524524

0 commit comments

Comments
 (0)