Skip to content

Commit e38c2ad

Browse files
committed
refactor part 2
1 parent 8c2b368 commit e38c2ad

24 files changed

+585
-287
lines changed

pymongo/asynchronous/encryption.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
from pymongo.results import BulkWriteResult, DeleteResult
8888
from pymongo.ssl_support import get_ssl_context
8989
from pymongo.typings import _DocumentType, _DocumentTypeArg
90-
from pymongo.uri_parser import parse_host
90+
from pymongo.uri_parser_shared import parse_host
9191
from pymongo.write_concern import WriteConcern
9292

9393
if TYPE_CHECKING:

pymongo/asynchronous/mongo_client.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,11 @@
5959
cast,
6060
)
6161

62+
import pymongo.asynchronous.uri_parser
6263
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
6364
from bson.timestamp import Timestamp
64-
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
65-
from pymongo.asynchronous import client_session, database
65+
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser_shared
66+
from pymongo.asynchronous import client_session, database, uri_parser
6667
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
6768
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
6869
from pymongo.asynchronous.client_session import _EmptyServerSession
@@ -114,7 +115,7 @@
114115
_DocumentTypeArg,
115116
_Pipeline,
116117
)
117-
from pymongo.uri_parser import (
118+
from pymongo.uri_parser_shared import (
118119
SRV_SCHEME,
119120
_check_options,
120121
_handle_option_deprecations,
@@ -783,7 +784,7 @@ def __init__(
783784
# it must be a URI,
784785
# https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names
785786
if "/" in entity:
786-
res = uri_parser._validate_uri(
787+
res = pymongo.asynchronous.uri_parser._validate_uri(
787788
entity,
788789
port,
789790
validate=True,
@@ -799,7 +800,7 @@ def __init__(
799800
opts = res["options"]
800801
fqdn = res["fqdn"]
801802
else:
802-
seeds.update(uri_parser.split_hosts(entity, self._port))
803+
seeds.update(uri_parser_shared.split_hosts(entity, self._port))
803804
if not seeds:
804805
raise ConfigurationError("need to specify at least one host")
805806

@@ -862,7 +863,7 @@ def __init__(
862863
if _IS_SYNC and connect:
863864
self._get_topology() # type: ignore[unused-coroutine]
864865

865-
def _resolve_srv(self) -> None:
866+
async def _resolve_srv(self) -> None:
866867
keyword_opts = self._resolve_srv_info["keyword_opts"]
867868
seeds = set()
868869
opts = common._CaseInsensitiveDictionary()
@@ -879,7 +880,7 @@ def _resolve_srv(self) -> None:
879880
timeout = common.validate_timeout_or_none_or_zero(
880881
keyword_opts.cased_key("connecttimeoutms"), timeout
881882
)
882-
res = uri_parser._parse_srv(
883+
res = await uri_parser._parse_srv(
883884
entity,
884885
self._port,
885886
validate=True,
@@ -892,7 +893,7 @@ def _resolve_srv(self) -> None:
892893
seeds.update(res["nodelist"])
893894
opts = res["options"]
894895
else:
895-
seeds.update(uri_parser.split_hosts(entity, self._port))
896+
seeds.update(uri_parser_shared.split_hosts(entity, self._port))
896897

897898
if not seeds:
898899
raise ConfigurationError("need to specify at least one host")
@@ -1694,7 +1695,7 @@ async def _get_topology(self) -> Topology:
16941695
"""
16951696
if not self._opened:
16961697
if self._resolve_srv_info["is_srv"]:
1697-
self._resolve_srv()
1698+
await self._resolve_srv()
16981699
self._init_background(first=True)
16991700
await self._topology.open()
17001701
async with self._lock:

pymongo/asynchronous/monitor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ async def _run(self) -> None:
395395
# Don't poll right after creation, wait 60 seconds first
396396
if time.monotonic() < self._startup_time + common.MIN_SRV_RESCAN_INTERVAL:
397397
return
398-
seedlist = self._get_seedlist()
398+
seedlist = await self._get_seedlist()
399399
if seedlist:
400400
self._seedlist = seedlist
401401
try:
@@ -404,7 +404,7 @@ async def _run(self) -> None:
404404
# Topology was garbage-collected.
405405
await self.close()
406406

407-
def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
407+
async def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
408408
"""Poll SRV records for a seedlist.
409409
410410
Returns a list of ServerDescriptions.
@@ -415,7 +415,7 @@ def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
415415
self._settings.pool_options.connect_timeout,
416416
self._settings.srv_service_name,
417417
)
418-
seedlist, ttl = resolver.get_hosts_and_min_ttl()
418+
seedlist, ttl = await resolver.get_hosts_and_min_ttl()
419419
if len(seedlist) == 0:
420420
# As per the spec: this should be treated as a failure.
421421
raise Exception

pymongo/asynchronous/srv_resolver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ async def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
5858
else:
5959
from dns.asyncresolver import Resolver
6060

61-
return await Resolver.resolve(*args, **kwargs)
61+
return await Resolver.resolve(*args, **kwargs) # type:ignore[return-value]
6262

6363

6464
_INVALID_HOST_MSG = (

pymongo/asynchronous/uri_parser.py

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
from typing import Any, Optional
5+
from urllib.parse import unquote_plus
6+
7+
from pymongo.asynchronous.srv_resolver import _have_dnspython, _SrvResolver
8+
from pymongo.common import SRV_SERVICE_NAME, _CaseInsensitiveDictionary
9+
from pymongo.errors import ConfigurationError, InvalidURI
10+
from pymongo.uri_parser_shared import (
11+
_ALLOWED_TXT_OPTS,
12+
_BAD_DB_CHARS,
13+
DEFAULT_PORT,
14+
SCHEME,
15+
SCHEME_LEN,
16+
SRV_SCHEME,
17+
SRV_SCHEME_LEN,
18+
_check_options,
19+
parse_userinfo,
20+
split_hosts,
21+
split_options,
22+
)
23+
24+
_IS_SYNC = False
25+
26+
27+
async def parse_uri(
    uri: str,
    default_port: Optional[int] = DEFAULT_PORT,
    validate: bool = True,
    warn: bool = False,
    normalize: bool = True,
    connect_timeout: Optional[float] = None,
    srv_service_name: Optional[str] = None,
    srv_max_hosts: Optional[int] = None,
) -> dict[str, Any]:
    """Parse and validate a MongoDB URI.

    The URI is first split and validated by ``_validate_uri``; the
    ``nodelist`` and ``options`` entries of that result are then replaced
    by the values produced by ``_parse_srv``.

    Returns a dict of the form::

        {
            'nodelist': <list of (host, port) tuples>,
            'username': <username> or None,
            'password': <password> or None,
            'database': <database name> or None,
            'collection': <collection name> or None,
            'options': <dict of MongoDB URI options>,
            'fqdn': <fqdn of the MongoDB+SRV URI> or None
        }

    If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
    to build nodelist and options.

    :param uri: The MongoDB URI to parse.
    :param default_port: The port number to use when one wasn't specified
        for a host in the URI.
    :param validate: If ``True`` (the default), validate and
        normalize all options. Default: ``True``.
    :param warn: When validating, if ``True`` then will warn
        the user then ignore any invalid options or values. If ``False``,
        validation will error when options are unsupported or values are
        invalid. Default: ``False``.
    :param normalize: If ``True``, convert names of URI options
        to their internally-used names. Default: ``True``.
    :param connect_timeout: The maximum time to wait for a response from
        the DNS server. Historically documented as milliseconds;
        NOTE(review): confirm the unit the SRV resolver actually expects.
    :param srv_service_name: A custom SRV service name
    :param srv_max_hosts: Optional limit forwarded to URI validation and to
        the SRV resolver. When falsy, the ``srvMaxHosts`` URI option is
        used instead.

    .. versionchanged:: 4.6
       The delimiting slash (``/``) between hosts and connection options is now optional.
       For example, "mongodb://example.com?tls=true" is now a valid URI.

    .. versionchanged:: 4.0
       To better follow RFC 3986, unquoted percent signs ("%") are no longer
       supported.

    .. versionchanged:: 3.9
       Added the ``normalize`` parameter.

    .. versionchanged:: 3.6
       Added support for mongodb+srv:// URIs.

    .. versionchanged:: 3.5
       Return the original value of the ``readPreference`` MongoDB URI option
       instead of the validated read preference mode.

    .. versionchanged:: 3.1
       ``warn`` added so invalid options can be ignored.
    """
    # Synchronous pass: structural validation and component extraction.
    parsed = _validate_uri(
        uri,
        default_port=default_port,
        validate=validate,
        warn=warn,
        normalize=normalize,
        srv_max_hosts=srv_max_hosts,
    )
    # Asynchronous pass: yields the final "nodelist" and "options" entries.
    srv_fields = await _parse_srv(
        uri,
        default_port=default_port,
        validate=validate,
        warn=warn,
        normalize=normalize,
        connect_timeout=connect_timeout,
        srv_service_name=srv_service_name,
        srv_max_hosts=srv_max_hosts,
    )
    parsed.update(srv_fields)
    return parsed
104+
105+
106+
def _validate_uri(
    uri: str,
    default_port: Optional[int] = DEFAULT_PORT,
    validate: bool = True,
    warn: bool = False,
    normalize: bool = True,
    srv_max_hosts: Optional[int] = None,
) -> dict[str, Any]:
    """Split a MongoDB URI into its components and validate its structure.

    No SRV resolver is created here; for ``mongodb+srv://`` URIs the returned
    ``nodelist`` holds only the single host taken from the URI and ``fqdn``
    is set to that host name.

    :param uri: The MongoDB URI to validate.
    :param default_port: Port passed to ``split_hosts`` for non-SRV URIs.
    :param validate: Forwarded to ``split_options``.
    :param warn: Forwarded to ``split_options``.
    :param normalize: Forwarded to ``split_options``.
    :param srv_max_hosts: When falsy, the ``srvMaxHosts`` URI option is used
        instead.
    :return: A dict with the keys ``nodelist``, ``username``, ``password``,
        ``database``, ``collection``, ``options`` and ``fqdn``.
    :raises InvalidURI: For an unknown scheme, a URI with nothing after the
        scheme, a bad database name, an unescaped ``/`` in the host section,
        or an SRV URI that does not name exactly one host without a port.
    :raises ConfigurationError: If dnspython is unavailable for an SRV URI,
        if ``directConnection`` is combined with an SRV URI, or if
        ``srvServiceName`` / ``srvMaxHosts`` is used with a non-SRV URI.
    """
    # Work out which scheme we are dealing with and strip it off.
    if uri.startswith(SCHEME):
        is_srv = False
        remainder = uri[SCHEME_LEN:]
    elif uri.startswith(SRV_SCHEME):
        if not _have_dnspython():
            python_path = sys.executable or "python"
            raise ConfigurationError(
                'The "dnspython" module must be '
                "installed to use mongodb+srv:// URIs. "
                "To fix this error install pymongo again:\n "
                "%s -m pip install pymongo>=4.3" % (python_path)
            )
        is_srv = True
        remainder = uri[SRV_SCHEME_LEN:]
    else:
        raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'")

    if not remainder:
        raise InvalidURI("Must provide at least one hostname or IP")

    username = None
    password = None
    database = None
    collection = None
    options = _CaseInsensitiveDictionary()

    authority_and_path, _, query = remainder.partition("?")
    # partition() returns an empty path when there is no "/", which the
    # truthiness test below treats exactly like "no database given".
    authority, _, db_path = authority_and_path.partition("/")

    if db_path:
        database = unquote_plus(db_path)
        if "." in database:
            # Everything after the first dot names the collection.
            database, collection = database.split(".", 1)
        if _BAD_DB_CHARS.search(database):
            raise InvalidURI('Bad database name "%s"' % database)

    if query:
        options.update(split_options(query, validate, warn, normalize))

    # The last "@" separates credentials from the host list; without one
    # rpartition() leaves the whole authority in the final slot.
    userinfo, at_sign, hosts = authority.rpartition("@")
    if at_sign:
        username, password = parse_userinfo(userinfo)

    if "/" in hosts:
        raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % authority)

    hosts = unquote_plus(hosts)
    fqdn = None
    srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")

    if is_srv:
        if options.get("directConnection"):
            raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs")
        nodes = split_hosts(hosts, default_port=None)
        if len(nodes) != 1:
            raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname")
        fqdn, port = nodes[0]
        if port is not None:
            raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number")
    else:
        # SRV-only options are rejected before the host list is parsed.
        if options.get("srvServiceName") is not None:
            raise ConfigurationError(
                "The srvServiceName option is only allowed with 'mongodb+srv://' URIs"
            )
        if srv_max_hosts:
            raise ConfigurationError(
                "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs"
            )
        nodes = split_hosts(hosts, default_port=default_port)

    _check_options(nodes, options)

    return {
        "nodelist": nodes,
        "username": username,
        "password": password,
        "database": database,
        "collection": collection,
        "options": options,
        "fqdn": fqdn,
    }
200+
201+
202+
async def _parse_srv(
    uri: str,
    default_port: Optional[int] = DEFAULT_PORT,
    validate: bool = True,
    warn: bool = False,
    normalize: bool = True,
    connect_timeout: Optional[float] = None,
    srv_service_name: Optional[str] = None,
    srv_max_hosts: Optional[int] = None,
) -> dict[str, Any]:
    """Build the node list and options for a URI, using DNS for SRV URIs.

    The scheme is not validated here: any URI that does not start with
    ``SCHEME`` is handled as a ``mongodb+srv://`` URI.

    :param uri: The MongoDB URI to resolve.
    :param default_port: Port passed to ``split_hosts`` for non-SRV URIs.
    :param validate: Forwarded to ``split_options``; also selects whether the
        implicit ``tls`` default is stored as ``True`` or ``"true"``.
    :param warn: Forwarded to ``split_options``.
    :param normalize: Forwarded to ``split_options``.
    :param connect_timeout: Passed to the SRV resolver; when falsy, the
        ``connectTimeoutMS`` URI option is used instead.
    :param srv_service_name: SRV service name; when ``None``, the
        ``srvServiceName`` URI option or ``SRV_SERVICE_NAME`` is used.
    :param srv_max_hosts: Passed to the SRV resolver; when falsy, the
        ``srvMaxHosts`` URI option is used instead.
    :return: A dict with the keys ``nodelist`` and ``options``.
    :raises ConfigurationError: If the DNS TXT record carries an option
        outside ``_ALLOWED_TXT_OPTS``.
    :raises InvalidURI: If ``srvMaxHosts`` is combined with ``loadBalanced``
        or ``replicaSet`` on an SRV URI.
    """
    is_srv = not uri.startswith(SCHEME)
    scheme_free = uri[SRV_SCHEME_LEN:] if is_srv else uri[SCHEME_LEN:]

    options = _CaseInsensitiveDictionary()

    host_plus_db_part, _, query = scheme_free.partition("?")
    # Only the text before the first "/" (if any) is the host section.
    host_part = host_plus_db_part.partition("/")[0]

    if query:
        options.update(split_options(query, validate, warn, normalize))
    if srv_service_name is None:
        srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)

    # Drop any "user:pass@" prefix; rpartition() keeps the whole string in
    # the last slot when no "@" is present.
    hosts = unquote_plus(host_part.rpartition("@")[2])
    srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")

    if not is_srv:
        nodes = split_hosts(hosts, default_port=default_port)
    else:
        nodes = split_hosts(hosts, default_port=None)
        fqdn, _port = nodes[0]

        # Use the connection timeout. connectTimeoutMS passed as a keyword
        # argument overrides the same option passed in the connection string.
        connect_timeout = connect_timeout or options.get("connectTimeoutMS")
        resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
        nodes = await resolver.get_hosts()
        txt_record = await resolver.get_options()
        if txt_record:
            txt_options = split_options(txt_record, validate, warn, normalize)
            unsupported = set(txt_options) - _ALLOWED_TXT_OPTS
            if unsupported:
                raise ConfigurationError(
                    "Only authSource, replicaSet, and loadBalanced are supported from DNS"
                )
            # Options given explicitly in the URI win over those from DNS.
            for name, value in txt_options.items():
                if name not in options:
                    options[name] = value

        if srv_max_hosts:
            if options.get("loadBalanced"):
                raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
            if options.get("replicaSet"):
                raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")

        # SRV URIs enable TLS unless the URI or TXT record says otherwise.
        if not ("tls" in options or "ssl" in options):
            if validate:
                options["tls"] = True
            else:
                options["tls"] = "true"

    _check_options(nodes, options)

    return {
        "nodelist": nodes,
        "options": options,
    }

0 commit comments

Comments
 (0)