-
Notifications
You must be signed in to change notification settings - Fork 1.1k
PYTHON-3636 MongoClient should perform SRV resolution lazily #2191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 61 commits
ead780a
79c09ea
3afd732
7d771cb
0f64689
ed50141
ed25867
8d48f44
1a3efed
d94743b
ad20606
dfa0639
57edcbc
58a58a0
35a41e9
d343311
d03c78f
e1d091f
8efd549
4c06dec
511fcc4
40509a1
97e0778
1de56d4
2653a56
4c23ee0
bc61199
8c2b368
e38c2ad
3c1bb28
94fec44
99a07fe
32fabb9
af568da
82bcd38
60bf17d
63ba7be
7585e04
2c69412
f834b89
d450457
2a8b1b2
c82cf50
9256808
c6d2ceb
d8d2c26
8927a27
76a68b2
b60eb60
0b6d303
0ca6afd
259d36b
d616135
379dfb6
5466484
2900718
63676b6
cd9bd92
f33d091
a7c090d
93bc3c9
afd82f4
99a5c8a
21d3f58
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,158 @@ | ||||||
# Copyright 2019-present MongoDB, Inc. | ||||||
# | ||||||
# Licensed under the Apache License, Version 2.0 (the "License"); you | ||||||
# may not use this file except in compliance with the License. You | ||||||
# may obtain a copy of the License at | ||||||
# | ||||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||||
# | ||||||
# Unless required by applicable law or agreed to in writing, software | ||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||||
# implied. See the License for the specific language governing | ||||||
# permissions and limitations under the License. | ||||||
|
||||||
"""Support for resolving hosts and options from mongodb+srv:// URIs.""" | ||||||
from __future__ import annotations | ||||||
|
||||||
import ipaddress | ||||||
import random | ||||||
from typing import TYPE_CHECKING, Any, Optional, Union | ||||||
|
||||||
from pymongo.common import CONNECT_TIMEOUT | ||||||
from pymongo.errors import ConfigurationError | ||||||
|
||||||
if TYPE_CHECKING: | ||||||
from dns import resolver | ||||||
|
||||||
_IS_SYNC = False | ||||||
|
||||||
|
||||||
def _have_dnspython() -> bool: | ||||||
try: | ||||||
import dns # noqa: F401 | ||||||
|
||||||
return True | ||||||
except ImportError: | ||||||
return False | ||||||
|
||||||
|
||||||
# dnspython can return bytes or str from various parts | ||||||
# of its API depending on version. We always want str. | ||||||
def maybe_decode(text: Union[str, bytes]) -> str: | ||||||
if isinstance(text, bytes): | ||||||
return text.decode() | ||||||
return text | ||||||
|
||||||
|
||||||
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet. | ||||||
async def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: | ||||||
if _IS_SYNC: | ||||||
from dns import resolver | ||||||
|
||||||
if hasattr(resolver, "resolve"): | ||||||
# dnspython >= 2 | ||||||
return resolver.resolve(*args, **kwargs) | ||||||
# dnspython 1.X | ||||||
return resolver.query(*args, **kwargs) | ||||||
else: | ||||||
from dns import asyncresolver | ||||||
|
||||||
if hasattr(asyncresolver, "resolve"): | ||||||
# dnspython >= 2 | ||||||
return await asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value] | ||||||
raise ConfigurationError("Upgrade to dnspython version >= 2.0") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This error message should explicitly inform users that they are attempting to use the async API with an old dnspython version. Telling them only to upgrade without any other information is inconsistent with the underlying reason. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a commit suggestion here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
|
||||||
_INVALID_HOST_MSG = ( | ||||||
"Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. " | ||||||
"Did you mean to use 'mongodb://'?" | ||||||
) | ||||||
|
||||||
|
||||||
class _SrvResolver: | ||||||
def __init__( | ||||||
self, | ||||||
fqdn: str, | ||||||
connect_timeout: Optional[float], | ||||||
srv_service_name: str, | ||||||
srv_max_hosts: int = 0, | ||||||
): | ||||||
self.__fqdn = fqdn | ||||||
self.__srv = srv_service_name | ||||||
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT | ||||||
self.__srv_max_hosts = srv_max_hosts or 0 | ||||||
# Validate the fully qualified domain name. | ||||||
try: | ||||||
ipaddress.ip_address(fqdn) | ||||||
raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",)) | ||||||
except ValueError: | ||||||
pass | ||||||
|
||||||
try: | ||||||
self.__plist = self.__fqdn.split(".")[1:] | ||||||
except Exception: | ||||||
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None | ||||||
self.__slen = len(self.__plist) | ||||||
if self.__slen < 2: | ||||||
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) | ||||||
|
||||||
async def get_options(self) -> Optional[str]: | ||||||
from dns import resolver | ||||||
|
||||||
try: | ||||||
results = await _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout) | ||||||
except (resolver.NoAnswer, resolver.NXDOMAIN): | ||||||
# No TXT records | ||||||
return None | ||||||
except Exception as exc: | ||||||
raise ConfigurationError(str(exc)) from None | ||||||
if len(results) > 1: | ||||||
raise ConfigurationError("Only one TXT record is supported") | ||||||
return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8") # type: ignore[attr-defined] | ||||||
|
||||||
async def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer: | ||||||
try: | ||||||
results = await _resolve( | ||||||
"_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout | ||||||
) | ||||||
except Exception as exc: | ||||||
if not encapsulate_errors: | ||||||
# Raise the original error. | ||||||
raise | ||||||
# Else, raise all errors as ConfigurationError. | ||||||
raise ConfigurationError(str(exc)) from None | ||||||
return results | ||||||
|
||||||
async def _get_srv_response_and_hosts( | ||||||
self, encapsulate_errors: bool | ||||||
) -> tuple[resolver.Answer, list[tuple[str, Any]]]: | ||||||
results = await self._resolve_uri(encapsulate_errors) | ||||||
|
||||||
# Construct address tuples | ||||||
nodes = [ | ||||||
(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) # type: ignore[attr-defined] | ||||||
for res in results | ||||||
] | ||||||
|
||||||
# Validate hosts | ||||||
for node in nodes: | ||||||
try: | ||||||
nlist = node[0].lower().split(".")[1:][-self.__slen :] | ||||||
except Exception: | ||||||
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None | ||||||
if self.__plist != nlist: | ||||||
raise ConfigurationError(f"Invalid SRV host: {node[0]}") | ||||||
if self.__srv_max_hosts: | ||||||
nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes))) | ||||||
return results, nodes | ||||||
|
||||||
async def get_hosts(self) -> list[tuple[str, Any]]: | ||||||
_, nodes = await self._get_srv_response_and_hosts(True) | ||||||
return nodes | ||||||
|
||||||
async def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]: | ||||||
results, nodes = await self._get_srv_response_and_hosts(False) | ||||||
rrset = results.rrset | ||||||
ttl = rrset.ttl if rrset else 0 | ||||||
return nodes, ttl |
Uh oh!
There was an error while loading. Please reload this page.