Skip to content

Commit 62400e5

Browse files
committed
PYTHON-1902 DNS resolution should honor connectTimeoutMS
1 parent aadd9c7 commit 62400e5

File tree

9 files changed

+128
-52
lines changed

9 files changed

+128
-52
lines changed

doc/changelog.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,11 @@ Unavoidable breaking changes:
7878
to avoid running into :class:`~pymongo.errors.OperationFailure` exceptions
7979
during write operations. The MMAPv1 storage engine is deprecated and does
8080
not support retryable writes which are now turned on by default.
81+
- In order to ensure that the ``connectTimeoutMS`` URI option is honored when
82+
connecting to clusters with a ``mongodb+srv://`` connection string, the
83+
minimum required version of the optional ``dnspython`` dependency has been
84+
bumped to 1.16.0. This is a breaking change for applications that use
85+
PyMongo's SRV support with a version of ``dnspython`` older than 1.16.0.
8186

8287
.. _URI options specification: https://github.com/mongodb/specifications/blob/master/source/uri-options/uri-options.rst
8388

pymongo/mongo_client.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,16 @@ def __init__(
592592
if not isinstance(port, int):
593593
raise TypeError("port must be an instance of int")
594594

595+
# _pool_class, _monitor_class, and _condition_class are for deep
596+
# customization of PyMongo, e.g. Motor.
597+
pool_class = kwargs.pop('_pool_class', None)
598+
monitor_class = kwargs.pop('_monitor_class', None)
599+
condition_class = kwargs.pop('_condition_class', None)
600+
601+
# Parse options passed as kwargs.
602+
keyword_opts = common._CaseInsensitiveDictionary(kwargs)
603+
keyword_opts['document_class'] = document_class
604+
595605
seeds = set()
596606
username = None
597607
password = None
@@ -600,8 +610,14 @@ def __init__(
600610
fqdn = None
601611
for entity in host:
602612
if "://" in entity:
613+
# Determine connection timeout from kwargs.
614+
timeout = keyword_opts.get("connecttimeoutms")
615+
if timeout is not None:
616+
timeout = common.validate_timeout_or_none(
617+
keyword_opts.cased_key("connecttimeoutms"), timeout)
603618
res = uri_parser.parse_uri(
604-
entity, port, validate=True, warn=True, normalize=False)
619+
entity, port, validate=True, warn=True, normalize=False,
620+
connect_timeout=timeout)
605621
seeds.update(res["nodelist"])
606622
username = res["username"] or username
607623
password = res["password"] or password
@@ -613,14 +629,7 @@ def __init__(
613629
if not seeds:
614630
raise ConfigurationError("need to specify at least one host")
615631

616-
# _pool_class, _monitor_class, and _condition_class are for deep
617-
# customization of PyMongo, e.g. Motor.
618-
pool_class = kwargs.pop('_pool_class', None)
619-
monitor_class = kwargs.pop('_monitor_class', None)
620-
condition_class = kwargs.pop('_condition_class', None)
621-
622-
keyword_opts = common._CaseInsensitiveDictionary(kwargs)
623-
keyword_opts['document_class'] = document_class
632+
# Add options with named keyword arguments to the parsed kwarg options.
624633
if type_registry is not None:
625634
keyword_opts['type_registry'] = type_registry
626635
if tz_aware is None:

pymongo/srv_resolver.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from bson.py3compat import PY3
2424

25+
from pymongo.common import CONNECT_TIMEOUT
2526
from pymongo.errors import ConfigurationError
2627

2728

@@ -38,8 +39,9 @@ def maybe_decode(text):
3839

3940

4041
class _SrvResolver(object):
41-
def __init__(self, fqdn):
42+
def __init__(self, fqdn, connect_timeout=None):
4243
self.__fqdn = fqdn
44+
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
4345

4446
# Validate the fully qualified domain name.
4547
try:
@@ -52,7 +54,8 @@ def __init__(self, fqdn):
5254

5355
def get_options(self):
5456
try:
55-
results = resolver.query(self.__fqdn, 'TXT')
57+
results = resolver.query(self.__fqdn, 'TXT',
58+
lifetime=self.__connect_timeout)
5659
except (resolver.NoAnswer, resolver.NXDOMAIN):
5760
# No TXT records
5861
return None
@@ -66,7 +69,8 @@ def get_options(self):
6669

6770
def _resolve_uri(self, encapsulate_errors):
6871
try:
69-
results = resolver.query('_mongodb._tcp.' + self.__fqdn, 'SRV')
72+
results = resolver.query('_mongodb._tcp.' + self.__fqdn, 'SRV',
73+
lifetime=self.__connect_timeout)
7074
except Exception as exc:
7175
if not encapsulate_errors:
7276
# Raise the original error.

pymongo/uri_parser.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def split_hosts(hosts, default_port=DEFAULT_PORT):
325325

326326

327327
def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
328-
normalize=True):
328+
normalize=True, connect_timeout=None):
329329
"""Parse and validate a MongoDB URI.
330330
331331
Returns a dict of the form::
@@ -355,6 +355,8 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
355355
invalid. Default: ``False``.
356356
- `normalize` (optional): If ``True``, convert names of URI options
357357
to their internally-used names. Default: ``True``.
358+
- `connect_timeout` (optional): The maximum time in milliseconds to
359+
wait for a response from the DNS server.
358360
359361
.. versionchanged:: 3.9
360362
Added the ``normalize`` parameter.
@@ -400,6 +402,25 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
400402
raise InvalidURI("A '/' is required between "
401403
"the host list and any options.")
402404

405+
if path_part:
406+
if path_part[0] == '?':
407+
opts = unquote_plus(path_part[1:])
408+
else:
409+
dbase, _, opts = map(unquote_plus, path_part.partition('?'))
410+
if '.' in dbase:
411+
dbase, collection = dbase.split('.', 1)
412+
413+
if _BAD_DB_CHARS.search(dbase):
414+
raise InvalidURI('Bad database name "%s"' % dbase)
415+
416+
if opts:
417+
options.update(split_options(opts, validate, warn, normalize))
418+
419+
if dbase is not None:
420+
dbase = unquote_plus(dbase)
421+
if collection is not None:
422+
collection = unquote_plus(collection)
423+
403424
if '@' in host_part:
404425
userinfo, _, hosts = host_part.rpartition('@')
405426
user, passwd = parse_userinfo(userinfo)
@@ -424,37 +445,26 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
424445
raise InvalidURI(
425446
"%s URIs must not include a port number" % (SRV_SCHEME,))
426447

427-
dns_resolver = _SrvResolver(fqdn)
448+
# Use the connection timeout. connectTimeoutMS passed as a keyword
449+
# argument overrides the same option passed in the connection string.
450+
connect_timeout = connect_timeout or options.get("connectTimeoutMS")
451+
dns_resolver = _SrvResolver(fqdn, connect_timeout=connect_timeout)
428452
nodes = dns_resolver.get_hosts()
429453
dns_options = dns_resolver.get_options()
430454
if dns_options:
431-
options = split_options(dns_options, validate, warn, normalize)
432-
if set(options) - _ALLOWED_TXT_OPTS:
455+
parsed_dns_options = split_options(
456+
dns_options, validate, warn, normalize)
457+
if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
433458
raise ConfigurationError(
434459
"Only authSource and replicaSet are supported from DNS")
435-
options["ssl"] = True if validate else 'true'
460+
for opt, val in parsed_dns_options.items():
461+
if opt not in options:
462+
options[opt] = val
463+
if "ssl" not in options:
464+
options["ssl"] = True if validate else 'true'
436465
else:
437466
nodes = split_hosts(hosts, default_port=default_port)
438467

439-
if path_part:
440-
if path_part[0] == '?':
441-
opts = unquote_plus(path_part[1:])
442-
else:
443-
dbase, _, opts = map(unquote_plus, path_part.partition('?'))
444-
if '.' in dbase:
445-
dbase, collection = dbase.split('.', 1)
446-
447-
if _BAD_DB_CHARS.search(dbase):
448-
raise InvalidURI('Bad database name "%s"' % dbase)
449-
450-
if opts:
451-
options.update(split_options(opts, validate, warn, normalize))
452-
453-
if dbase is not None:
454-
dbase = unquote_plus(dbase)
455-
if collection is not None:
456-
collection = unquote_plus(collection)
457-
458468
return {
459469
'nodelist': nodes,
460470
'username': user,

setup.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -318,17 +318,16 @@ def build_extension(self, ext):
318318
'bson/buffer.c'])]
319319

320320
extras_require = {
321+
'encryption': ['pymongocrypt'], # For client side field level encryption.
321322
'snappy': ['python-snappy'],
323+
'srv': ["dnspython>=1.16.0,<2.0.0"],
322324
'zstd': ['zstandard'],
323-
'encryption': ['pymongocrypt'], # For client side field level encryption.
324325
}
325326
vi = sys.version_info
326327
if vi[0] == 2:
327-
extras_require.update(
328-
{'tls': ["ipaddress"], 'srv': ["dnspython>=1.8.0,<2.0.0"]})
328+
extras_require.update({'tls': ["ipaddress"]})
329329
else:
330-
extras_require.update(
331-
{'tls': [], 'srv': ["dnspython>=1.13.0,<2.0.0"]})
330+
extras_require.update({'tls': []})
332331
if sys.platform == 'win32':
333332
extras_require['gssapi'] = ["winkerberos>=0.5.0"]
334333
if vi < (2, 7, 9):

test/test_client.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@
3333
from bson.py3compat import thread
3434
from bson.son import SON
3535
from bson.tz_util import utc
36+
import pymongo
3637
from pymongo import auth, message
37-
from pymongo.common import _UUID_REPRESENTATIONS
38+
from pymongo.common import CONNECT_TIMEOUT, _UUID_REPRESENTATIONS
3839
from pymongo.command_cursor import CommandCursor
3940
from pymongo.compression_support import _HAVE_SNAPPY, _HAVE_ZSTD
4041
from pymongo.cursor import Cursor, CursorType
@@ -56,6 +57,7 @@
5657
from pymongo.server_selectors import (any_server_selector,
5758
writable_server_selector)
5859
from pymongo.server_type import SERVER_TYPE
60+
from pymongo.srv_resolver import _HAVE_DNSPYTHON
5961
from pymongo.write_concern import WriteConcern
6062
from test import (client_context,
6163
client_knobs,
@@ -70,6 +72,7 @@
7072
from test.utils import (assertRaisesExactly,
7173
connected,
7274
delay,
75+
FunctionCallRecorder,
7376
get_pool,
7477
gevent_monkey_patched,
7578
ignore_deprecations,
@@ -358,6 +361,43 @@ def test_uri_option_precedence(self):
358361
self.assertEqual(
359362
clopts.read_preference, ReadPreference.SECONDARY_PREFERRED)
360363

364+
@unittest.skipUnless(
365+
_HAVE_DNSPYTHON, "DNS-related tests need dnspython to be installed")
366+
def test_connection_timeout_ms_propagates_to_DNS_resolver(self):
367+
# Patch the resolver.
368+
from pymongo.srv_resolver import resolver
369+
patched_resolver = FunctionCallRecorder(resolver.query)
370+
pymongo.srv_resolver.resolver.query = patched_resolver
371+
def reset_resolver():
372+
pymongo.srv_resolver.resolver.query = resolver.query
373+
self.addCleanup(reset_resolver)
374+
375+
# Setup.
376+
base_uri = "mongodb+srv://test5.test.build.10gen.cc"
377+
connectTimeoutMS = 5000
378+
expected_kw_value = 5.0
379+
uri_with_timeout = base_uri + "/?connectTimeoutMS=6000"
380+
expected_uri_value = 6.0
381+
382+
def test_scenario(args, kwargs, expected_value):
383+
patched_resolver.reset()
384+
MongoClient(*args, **kwargs)
385+
for _, kw in patched_resolver.call_list():
386+
self.assertAlmostEqual(kw['lifetime'], expected_value)
387+
388+
# No timeout specified.
389+
test_scenario((base_uri,), {}, CONNECT_TIMEOUT)
390+
391+
# Timeout only specified in connection string.
392+
test_scenario((uri_with_timeout,), {}, expected_uri_value)
393+
394+
# Timeout only specified in keyword arguments.
395+
kwarg = {'connectTimeoutMS': connectTimeoutMS}
396+
test_scenario((base_uri,), kwarg, expected_kw_value)
397+
398+
# Timeout specified in both kwargs and connection string.
399+
test_scenario((uri_with_timeout,), kwarg, expected_kw_value)
400+
361401
def test_uri_security_options(self):
362402
# Ensure that we don't silently override security-related options.
363403
with self.assertRaises(InvalidURI):

test/test_server_selection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
from test import client_context, unittest, IntegrationTest
3030
from test.utils import (rs_or_single_client, wait_until, EventListener,
31-
FunctionCallCounter)
31+
FunctionCallRecorder)
3232
from test.utils_selection_tests import (
3333
create_selection_tests, get_addresses, get_topology_settings_dict,
3434
make_server_description)
@@ -105,7 +105,7 @@ def test_invalid_server_selector(self):
105105

106106
@client_context.require_replica_set
107107
def test_selector_called(self):
108-
selector = FunctionCallCounter(lambda x: x)
108+
selector = FunctionCallRecorder(lambda x: x)
109109

110110
# Client setup.
111111
mongo_client = rs_or_single_client(server_selector=selector)
@@ -169,7 +169,7 @@ def test_latency_threshold_application(self):
169169

170170
@client_context.require_replica_set
171171
def test_server_selector_bypassed(self):
172-
selector = FunctionCallCounter(lambda x: x)
172+
selector = FunctionCallRecorder(lambda x: x)
173173

174174
scenario_def = {
175175
'topology_description': {

test/test_srv_polling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from pymongo.srv_resolver import _HAVE_DNSPYTHON
2828
from pymongo.mongo_client import MongoClient
2929
from test import client_knobs, unittest
30-
from test.utils import wait_until, FunctionCallCounter
30+
from test.utils import wait_until, FunctionCallRecorder
3131

3232

3333
WAIT_TIME = 0.1
@@ -62,7 +62,7 @@ def mock_get_hosts_and_min_ttl(resolver, *args):
6262
return nodes, ttl
6363

6464
if self.count_resolver_calls:
65-
patch_func = FunctionCallCounter(mock_get_hosts_and_min_ttl)
65+
patch_func = FunctionCallRecorder(mock_get_hosts_and_min_ttl)
6666
else:
6767
patch_func = mock_get_hosts_and_min_ttl
6868

test/utils.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,19 +216,28 @@ def __ne__(self, other):
216216
return not self.__eq__(other)
217217

218218

219-
class FunctionCallCounter(object):
220-
"""Class that wraps a function and keeps count of invocations."""
219+
class FunctionCallRecorder(object):
220+
"""Utility class to wrap a callable and record its invocations."""
221221
def __init__(self, function):
222222
self._function = function
223-
self._call_count = 0
223+
self._call_list = []
224224

225225
def __call__(self, *args, **kwargs):
226-
self._call_count += 1
226+
self._call_list.append((args, kwargs))
227227
return self._function(*args, **kwargs)
228228

229+
def reset(self):
230+
"""Wipes the call list."""
231+
self._call_list = []
232+
233+
def call_list(self):
234+
"""Returns a copy of the call list."""
235+
return self._call_list[:]
236+
229237
@property
230238
def call_count(self):
231-
return self._call_count
239+
"""Returns the number of times the function has been called."""
240+
return len(self._call_list)
232241

233242

234243
class TestCreator(object):

0 commit comments

Comments
 (0)