Skip to content

Commit ead780a

Browse files
committed
WIP (not cleaned up)
1 parent 38f97a3 commit ead780a

File tree

5 files changed

+404
-48
lines changed

5 files changed

+404
-48
lines changed

pymongo/asynchronous/mongo_client.py

Lines changed: 141 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -742,13 +742,15 @@ def __init__(
742742
**kwargs,
743743
}
744744

745-
if host is None:
746-
host = self.HOST
747-
if isinstance(host, str):
748-
host = [host]
749-
if port is None:
750-
port = self.PORT
751-
if not isinstance(port, int):
745+
self._host = host
746+
self._port = port
747+
if self._host is None:
748+
self._host = self.HOST
749+
if isinstance(self._host, str):
750+
self._host = [self._host]
751+
if self._port is None:
752+
self._port = self.PORT
753+
if not isinstance(self._port, int):
752754
raise TypeError(f"port must be an instance of int, not {type(port)}")
753755

754756
# _pool_class, _monitor_class, and _condition_class are for deep
@@ -769,26 +771,19 @@ def __init__(
769771
fqdn = None
770772
srv_service_name = keyword_opts.get("srvservicename")
771773
srv_max_hosts = keyword_opts.get("srvmaxhosts")
772-
if len([h for h in host if "/" in h]) > 1:
774+
if len([h for h in self._host if "/" in h]) > 1:
773775
raise ConfigurationError("host must not contain multiple MongoDB URIs")
774-
for entity in host:
776+
for entity in self._host:
775777
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
776778
# it must be a URI,
777779
# https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names
778780
if "/" in entity:
779-
# Determine connection timeout from kwargs.
780-
timeout = keyword_opts.get("connecttimeoutms")
781-
if timeout is not None:
782-
timeout = common.validate_timeout_or_none_or_zero(
783-
keyword_opts.cased_key("connecttimeoutms"), timeout
784-
)
785781
res = uri_parser.parse_uri(
786782
entity,
787783
port,
788784
validate=True,
789785
warn=True,
790786
normalize=False,
791-
connect_timeout=timeout,
792787
srv_service_name=srv_service_name,
793788
srv_max_hosts=srv_max_hosts,
794789
)
@@ -799,7 +794,7 @@ def __init__(
799794
opts = res["options"]
800795
fqdn = res["fqdn"]
801796
else:
802-
seeds.update(uri_parser.split_hosts(entity, port))
797+
seeds.update(uri_parser.split_hosts(entity, self._port))
803798
if not seeds:
804799
raise ConfigurationError("need to specify at least one host")
805800

@@ -895,6 +890,134 @@ def __init__(
895890
# This will be used later if we fork.
896891
AsyncMongoClient._clients[self._topology._topology_id] = self
897892

893+
self._for_resolve_uri = {
894+
"username": username,
895+
"password": password,
896+
"srv_service_name": srv_service_name,
897+
"srv_max_hosts": srv_max_hosts,
898+
"fqdn": fqdn,
899+
"pool_class": pool_class,
900+
"monitor_class": monitor_class,
901+
"condition_class": condition_class,
902+
}
903+
904+
def _resolve_uri(self):
905+
keyword_opts = common._CaseInsensitiveDictionary(self._init_kwargs)
906+
for i in [
907+
"_pool_class",
908+
"_monitor_class",
909+
"_condition_class",
910+
"host",
911+
"port",
912+
"type_registry",
913+
]:
914+
keyword_opts.pop(i, None)
915+
seeds = set()
916+
opts = common._CaseInsensitiveDictionary()
917+
srv_service_name = keyword_opts.get("srvservicename")
918+
srv_max_hosts = keyword_opts.get("srvmaxhosts")
919+
for entity in self._host:
920+
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
921+
# it must be a URI,
922+
# https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names
923+
if "/" in entity:
924+
# Determine connection timeout from kwargs.
925+
timeout = keyword_opts.get("connecttimeoutms")
926+
if timeout is not None:
927+
timeout = common.validate_timeout_or_none_or_zero(
928+
keyword_opts.cased_key("connecttimeoutms"), timeout
929+
)
930+
res = uri_parser.parse_uri_lookups(
931+
entity,
932+
self._port,
933+
validate=True,
934+
warn=True,
935+
normalize=False,
936+
connect_timeout=timeout,
937+
srv_service_name=srv_service_name,
938+
srv_max_hosts=srv_max_hosts,
939+
)
940+
seeds.update(res["nodelist"])
941+
opts = res["options"]
942+
else:
943+
seeds.update(uri_parser.split_hosts(entity, self._port))
944+
945+
if not seeds:
946+
raise ConfigurationError("need to specify at least one host")
947+
948+
for hostname in [node[0] for node in seeds]:
949+
if _detect_external_db(hostname):
950+
break
951+
952+
# Add options with named keyword arguments to the parsed kwarg options.
953+
tz_aware = keyword_opts["tz_aware"]
954+
connect = keyword_opts["connect"]
955+
if tz_aware is None:
956+
tz_aware = opts.get("tz_aware", False)
957+
if connect is None:
958+
# Default to connect=True unless on a FaaS system, which might use fork.
959+
from pymongo.pool_options import _is_faas
960+
961+
connect = opts.get("connect", not _is_faas())
962+
keyword_opts["tz_aware"] = tz_aware
963+
keyword_opts["connect"] = connect
964+
965+
# Handle deprecated options in kwarg options.
966+
keyword_opts = _handle_option_deprecations(keyword_opts)
967+
# Validate kwarg options.
968+
keyword_opts = common._CaseInsensitiveDictionary(
969+
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
970+
)
971+
972+
# Override connection string options with kwarg options.
973+
opts.update(keyword_opts)
974+
975+
if srv_service_name is None:
976+
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
977+
978+
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
979+
# Handle security-option conflicts in combined options.
980+
opts = _handle_security_options(opts)
981+
# Normalize combined options.
982+
opts = _normalize_options(opts)
983+
_check_options(seeds, opts)
984+
985+
# Username and password passed as kwargs override user info in URI.
986+
username = opts.get("username", self._for_resolve_uri["username"])
987+
password = opts.get("password", self._for_resolve_uri["password"])
988+
self._options = ClientOptions(
989+
username, password, self._default_database_name, opts, _IS_SYNC
990+
)
991+
992+
self._event_listeners = self._options.pool_options._event_listeners
993+
super().__init__(
994+
self._options.codec_options,
995+
self._options.read_preference,
996+
self._options.write_concern,
997+
self._options.read_concern,
998+
)
999+
1000+
self._topology_settings = TopologySettings(
1001+
seeds=seeds,
1002+
replica_set_name=self._options.replica_set_name,
1003+
pool_class=self._for_resolve_uri["pool_class"],
1004+
pool_options=self._options.pool_options,
1005+
monitor_class=self._for_resolve_uri["monitor_class"],
1006+
condition_class=self._for_resolve_uri["condition_class"],
1007+
local_threshold_ms=self._options.local_threshold_ms,
1008+
server_selection_timeout=self._options.server_selection_timeout,
1009+
server_selector=self._options.server_selector,
1010+
heartbeat_frequency=self._options.heartbeat_frequency,
1011+
fqdn=self._for_resolve_uri["fqdn"],
1012+
direct_connection=self._options.direct_connection,
1013+
load_balanced=self._options.load_balanced,
1014+
srv_service_name=srv_service_name,
1015+
srv_max_hosts=srv_max_hosts,
1016+
server_monitoring_mode=self._options.server_monitoring_mode,
1017+
)
1018+
1019+
self._topology = Topology(self._topology_settings)
1020+
8981021
async def aconnect(self) -> None:
8991022
"""Explicitly connect to MongoDB asynchronously instead of on the first operation."""
9001023
await self._get_topology()
@@ -1582,6 +1705,7 @@ async def _get_topology(self) -> Topology:
15821705
launches the connection process in the background.
15831706
"""
15841707
if not self._opened:
1708+
self._resolve_uri()
15851709
await self._topology.open()
15861710
async with self._lock:
15871711
self._kill_cursors_executor.open()

0 commit comments

Comments
 (0)