Skip to content

Commit 6bb8a1f

Browse files
authored
PYTHON-2823 Allow custom service names with srvServiceName URI option (#749)
1 parent 049daf9 commit 6bb8a1f

File tree

11 files changed

+111
-8
lines changed

11 files changed

+111
-8
lines changed

doc/changelog.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ Breaking Changes in 4.0
134134
- The ``hint`` option is now required when using ``min`` or ``max`` queries
135135
with :meth:`~pymongo.collection.Collection.find`.
136136
- ``name`` is now a required argument for the :class:`pymongo.driver_info.DriverInfo` class.
137+
- When providing a "mongodb+srv://" URI to
138+
:class:`~pymongo.mongo_client.MongoClient` constructor you can now use the
139+
``srvServiceName`` URI option to specify your own SRV service name.
137140
- :meth:`~bson.son.SON.items` now returns a ``dict_items`` object rather
138141
than a list.
139142
- Removed :meth:`bson.son.SON.iteritems`.
@@ -160,7 +163,6 @@ Breaking Changes in 4.0
160163
- ``MongoClient()`` now raises a :exc:`~pymongo.errors.ConfigurationError`
161164
when more than one URI is passed into the ``hosts`` argument.
162165

163-
164166
Notable improvements
165167
....................
166168

pymongo/common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@
113113
# From the driver sessions spec.
114114
_MAX_END_SESSIONS = 10000
115115

116+
# Default value for srvServiceName
117+
SRV_SERVICE_NAME = "mongodb"
118+
116119

117120
def partition_node(node):
118121
"""Split a host:port string into (host, int(port)) pair."""
@@ -626,6 +629,7 @@ def validate_tzinfo(dummy, value):
626629
'w': validate_non_negative_int_or_basestring,
627630
'wtimeoutms': validate_non_negative_integer,
628631
'zlibcompressionlevel': validate_zlib_compression_level,
632+
'srvservicename': validate_string
629633
}
630634

631635
# Dictionary where keys are the names of URI options specific to pymongo,

pymongo/mongo_client.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,11 @@ def __init__(
329329
a Unicode-related error occurs during BSON decoding that would
330330
otherwise raise :exc:`UnicodeDecodeError`. Valid options include
331331
'strict', 'replace', and 'ignore'. Defaults to 'strict'.
332+
- ``srvServiceName`: (string) The SRV service name to use for
333+
"mongodb+srv://" URIs. Defaults to "mongodb". Use it like so::
334+
335+
MongoClient("mongodb+srv://example.com/?srvServiceName=customname")
336+
332337
333338
| **Write Concern options:**
334339
| (Only set if passed. No default values.)
@@ -499,6 +504,7 @@ def __init__(
499504
arguments.
500505
The default for `uuidRepresentation` was changed from
501506
``pythonLegacy`` to ``unspecified``.
507+
Added the ``srvServiceName`` URI and keyword argument.
502508
503509
.. versionchanged:: 3.12
504510
Added the ``server_api`` keyword argument.
@@ -644,6 +650,8 @@ def __init__(
644650
dbase = None
645651
opts = common._CaseInsensitiveDictionary()
646652
fqdn = None
653+
srv_service_name = keyword_opts.get("srvservicename", None)
654+
647655
if len([h for h in host if "/" in h]) > 1:
648656
raise ConfigurationError("host must not contain multiple MongoDB "
649657
"URIs")
@@ -659,7 +667,7 @@ def __init__(
659667
keyword_opts.cased_key("connecttimeoutms"), timeout)
660668
res = uri_parser.parse_uri(
661669
entity, port, validate=True, warn=True, normalize=False,
662-
connect_timeout=timeout)
670+
connect_timeout=timeout, srv_service_name=srv_service_name)
663671
seeds.update(res["nodelist"])
664672
username = res["username"] or username
665673
password = res["password"] or password
@@ -689,6 +697,10 @@ def __init__(
689697

690698
# Override connection string options with kwarg options.
691699
opts.update(keyword_opts)
700+
701+
if srv_service_name is None:
702+
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
703+
692704
# Handle security-option conflicts in combined options.
693705
opts = _handle_security_options(opts)
694706
# Normalize combined options.
@@ -728,6 +740,7 @@ def __init__(
728740
server_selector=options.server_selector,
729741
heartbeat_frequency=options.heartbeat_frequency,
730742
fqdn=fqdn,
743+
srv_service_name=srv_service_name,
731744
direct_connection=options.direct_connection,
732745
load_balanced=options.load_balanced,
733746
)

pymongo/monitor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ def __init__(self, topology, topology_settings):
299299
self._settings = topology_settings
300300
self._seedlist = self._settings._seeds
301301
self._fqdn = self._settings.fqdn
302+
self._srv_service_name = self._settings._srv_service_name
302303

303304
def _run(self):
304305
seedlist = self._get_seedlist()
@@ -316,7 +317,10 @@ def _get_seedlist(self):
316317
Returns a list of ServerDescriptions.
317318
"""
318319
try:
319-
seedlist, ttl = _SrvResolver(self._fqdn).get_hosts_and_min_ttl()
320+
resolver = _SrvResolver(self._fqdn,
321+
self._settings.pool_options.connect_timeout,
322+
self._srv_service_name)
323+
seedlist, ttl = resolver.get_hosts_and_min_ttl()
320324
if len(seedlist) == 0:
321325
# As per the spec: this should be treated as a failure.
322326
raise Exception

pymongo/settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(self,
3939
heartbeat_frequency=common.HEARTBEAT_FREQUENCY,
4040
server_selector=None,
4141
fqdn=None,
42+
srv_service_name=common.SRV_SERVICE_NAME,
4243
direct_connection=False,
4344
load_balanced=None):
4445
"""Represent MongoClient's configuration.
@@ -60,6 +61,7 @@ def __init__(self,
6061
self._server_selection_timeout = server_selection_timeout
6162
self._server_selector = server_selector
6263
self._fqdn = fqdn
64+
self._srv_service_name = srv_service_name
6365
self._heartbeat_frequency = heartbeat_frequency
6466

6567
self._direct = direct_connection

pymongo/srv_resolver.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@ def _resolve(*args, **kwargs):
4747
"Did you mean to use 'mongodb://'?")
4848

4949
class _SrvResolver(object):
50-
def __init__(self, fqdn, connect_timeout=None):
50+
def __init__(self, fqdn,
51+
connect_timeout, srv_service_name):
5152
self.__fqdn = fqdn
53+
self.__srv = srv_service_name
5254
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
5355

5456
# Validate the fully qualified domain name.
@@ -83,8 +85,8 @@ def get_options(self):
8385

8486
def _resolve_uri(self, encapsulate_errors):
8587
try:
86-
results = _resolve('_mongodb._tcp.' + self.__fqdn, 'SRV',
87-
lifetime=self.__connect_timeout)
88+
results = _resolve('_' + self.__srv + '._tcp.' + self.__fqdn,
89+
'SRV', lifetime=self.__connect_timeout)
8890
except Exception as exc:
8991
if not encapsulate_errors:
9092
# Raise the original error.

pymongo/uri_parser.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from urllib.parse import unquote_plus
2222

2323
from pymongo.common import (
24+
SRV_SERVICE_NAME,
2425
get_validated_options, INTERNAL_URI_OPTION_NAME_MAP,
2526
URI_OPTIONS_DEPRECATION_MAP, _CaseInsensitiveDictionary)
2627
from pymongo.errors import ConfigurationError, InvalidURI
@@ -373,7 +374,7 @@ def _check_options(nodes, options):
373374

374375

375376
def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
376-
normalize=True, connect_timeout=None):
377+
normalize=True, connect_timeout=None, srv_service_name=None):
377378
"""Parse and validate a MongoDB URI.
378379
379380
Returns a dict of the form::
@@ -405,6 +406,7 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
405406
to their internally-used names. Default: ``True``.
406407
- `connect_timeout` (optional): The maximum time in milliseconds to
407408
wait for a response from the DNS server.
409+
- 'srv_service_name` (optional): A custom SRV service name
408410
409411
.. versionchanged:: 3.9
410412
Added the ``normalize`` parameter.
@@ -468,6 +470,9 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
468470
if opts:
469471
options.update(split_options(opts, validate, warn, normalize))
470472

473+
if srv_service_name is None:
474+
srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
475+
471476
if '@' in host_part:
472477
userinfo, _, hosts = host_part.rpartition('@')
473478
user, passwd = parse_userinfo(userinfo)
@@ -499,7 +504,7 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
499504
# Use the connection timeout. connectTimeoutMS passed as a keyword
500505
# argument overrides the same option passed in the connection string.
501506
connect_timeout = connect_timeout or options.get("connectTimeoutMS")
502-
dns_resolver = _SrvResolver(fqdn, connect_timeout=connect_timeout)
507+
dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name)
503508
nodes = dns_resolver.get_hosts()
504509
dns_options = dns_resolver.get_options()
505510
if dns_options:
@@ -514,6 +519,9 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
514519
options[opt] = val
515520
if "tls" not in options and "ssl" not in options:
516521
options["tls"] = True if validate else 'true'
522+
elif not is_srv and options.get("srvServiceName") is not None:
523+
raise ConfigurationError("The srvServiceName option is only allowed "
524+
"with 'mongodb+srv://' URIs")
517525
else:
518526
nodes = split_hosts(hosts, default_port=default_port)
519527

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"uri": "mongodb+srv://test4.test.build.10gen.cc/?loadBalanced=true",
3+
"seeds": [],
4+
"hosts": [],
5+
"error": true,
6+
"comment": "Should fail because no SRV records are present for this URI."
7+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"uri": "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname",
3+
"seeds": [
4+
"localhost.test.build.10gen.cc:27017",
5+
"localhost.test.build.10gen.cc:27018"
6+
],
7+
"hosts": [
8+
"localhost:27017",
9+
"localhost:27018",
10+
"localhost:27019"
11+
],
12+
"options": {
13+
"ssl": true,
14+
"srvServiceName": "customname"
15+
}
16+
}

test/test_client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1591,6 +1591,27 @@ def test_network_error_message(self):
15911591
with self.assertRaisesRegex(AutoReconnect, expected):
15921592
client.pymongo_test.test.find_one({})
15931593

1594+
@unittest.skipUnless(
1595+
_HAVE_DNSPYTHON, "DNS-related tests need dnspython to be installed")
1596+
def test_service_name_from_kwargs(self):
1597+
client = MongoClient(
1598+
'mongodb+srv://user:[email protected]',
1599+
srvServiceName='customname', connect=False)
1600+
self.assertEqual(client._topology_settings._srv_service_name,
1601+
'customname')
1602+
client = MongoClient(
1603+
'mongodb+srv://user:[email protected]'
1604+
'/?srvServiceName=shouldbeoverriden',
1605+
srvServiceName='customname', connect=False)
1606+
self.assertEqual(client._topology_settings._srv_service_name,
1607+
'customname')
1608+
client = MongoClient(
1609+
'mongodb+srv://user:[email protected]'
1610+
'/?srvServiceName=customname',
1611+
connect=False)
1612+
self.assertEqual(client._topology_settings._srv_service_name,
1613+
'customname')
1614+
15941615

15951616
class TestExhaustCursor(IntegrationTest):
15961617
"""Test that clients properly handle errors from exhaust cursors."""

0 commit comments

Comments
 (0)