Skip to content

Commit ed50141

Browse files
committed
some refactoring to reduce code duplication
1 parent 0f64689 commit ed50141

File tree

2 files changed

+102
-108
lines changed

2 files changed

+102
-108
lines changed

pymongo/asynchronous/mongo_client.py

Lines changed: 51 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,7 @@ def __init__(
762762
# Parse options passed as kwargs.
763763
keyword_opts = common._CaseInsensitiveDictionary(kwargs)
764764
keyword_opts["document_class"] = doc_class
765+
self._resolve_uri_info = {"keyword_opts": keyword_opts}
765766

766767
seeds = set()
767768
username = None
@@ -814,25 +815,13 @@ def __init__(
814815
keyword_opts["tz_aware"] = tz_aware
815816
keyword_opts["connect"] = connect
816817

817-
# Handle deprecated options in kwarg options.
818-
keyword_opts = _handle_option_deprecations(keyword_opts)
819-
# Validate kwarg options.
820-
keyword_opts = common._CaseInsensitiveDictionary(
821-
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
822-
)
823-
824-
# Override connection string options with kwarg options.
825-
opts.update(keyword_opts)
818+
opts = self._validate_kwargs_and_update_opts(keyword_opts, opts)
826819

827820
if srv_service_name is None:
828821
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
829822

830823
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
831-
# Handle security-option conflicts in combined options.
832-
opts = _handle_security_options(opts)
833-
# Normalize combined options.
834-
opts = _normalize_options(opts)
835-
_check_options(seeds, opts)
824+
opts = self._normalize_and_validate_options(opts, seeds)
836825

837826
# Username and password passed as kwargs override user info in URI.
838827
self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
@@ -872,15 +861,17 @@ def __init__(
872861
self._closed = False
873862
self._init_background()
874863

875-
self._for_resolve_uri = {
876-
"username": username,
877-
"password": password,
878-
"dbase": dbase,
879-
"fqdn": fqdn,
880-
"pool_class": pool_class,
881-
"monitor_class": monitor_class,
882-
"condition_class": condition_class,
883-
}
864+
self._resolve_uri_info.update(
865+
{
866+
"username": username,
867+
"password": password,
868+
"dbase": dbase,
869+
"fqdn": fqdn,
870+
"pool_class": pool_class,
871+
"monitor_class": monitor_class,
872+
"condition_class": condition_class,
873+
}
874+
)
884875
if _IS_SYNC and connect:
885876
self._get_topology() # type: ignore[unused-coroutine]
886877

@@ -896,17 +887,16 @@ def __init__(
896887
# This will be used later if we fork.
897888
AsyncMongoClient._clients[self._topology._topology_id] = self
898889

890+
def _normalize_and_validate_options(self, opts, seeds):
891+
# Handle security-option conflicts in combined options.
892+
opts = _handle_security_options(opts)
893+
# Normalize combined options.
894+
opts = _normalize_options(opts)
895+
_check_options(seeds, opts)
896+
return opts
897+
899898
def _resolve_uri(self):
900-
keyword_opts = common._CaseInsensitiveDictionary(self._init_kwargs)
901-
for i in [
902-
"_pool_class",
903-
"_monitor_class",
904-
"_condition_class",
905-
"host",
906-
"port",
907-
"type_registry",
908-
]:
909-
keyword_opts.pop(i, None)
899+
keyword_opts = self._resolve_uri_info["keyword_opts"]
910900
seeds = set()
911901
opts = common._CaseInsensitiveDictionary()
912902
srv_service_name = keyword_opts.get("srvservicename")
@@ -957,31 +947,19 @@ def _resolve_uri(self):
957947
keyword_opts["tz_aware"] = tz_aware
958948
keyword_opts["connect"] = connect
959949

960-
# Handle deprecated options in kwarg options.
961-
keyword_opts = _handle_option_deprecations(keyword_opts)
962-
# Validate kwarg options.
963-
keyword_opts = common._CaseInsensitiveDictionary(
964-
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
965-
)
966-
967-
# Override connection string options with kwarg options.
968-
opts.update(keyword_opts)
950+
opts = self._validate_kwargs_and_update_opts(keyword_opts, opts)
969951

970952
if srv_service_name is None:
971953
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
972954

973955
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
974-
# Handle security-option conflicts in combined options.
975-
opts = _handle_security_options(opts)
976-
# Normalize combined options.
977-
opts = _normalize_options(opts)
978-
_check_options(seeds, opts)
956+
opts = self._normalize_and_validate_opts(opts, seeds)
979957

980958
# Username and password passed as kwargs override user info in URI.
981-
username = opts.get("username", self._for_resolve_uri["username"])
982-
password = opts.get("password", self._for_resolve_uri["password"])
959+
username = opts.get("username", self._resolve_uri_info["username"])
960+
password = opts.get("password", self._resolve_uri_info["password"])
983961
self._options = ClientOptions(
984-
username, password, self._for_resolve_uri["dbase"], opts, _IS_SYNC
962+
username, password, self._resolve_uri_info["dbase"], opts, _IS_SYNC
985963
)
986964

987965
self._event_listeners = self._options.pool_options._event_listeners
@@ -995,15 +973,15 @@ def _resolve_uri(self):
995973
self._topology_settings = TopologySettings(
996974
seeds=seeds,
997975
replica_set_name=self._options.replica_set_name,
998-
pool_class=self._for_resolve_uri["pool_class"],
976+
pool_class=self._resolve_uri_info["pool_class"],
999977
pool_options=self._options.pool_options,
1000-
monitor_class=self._for_resolve_uri["monitor_class"],
1001-
condition_class=self._for_resolve_uri["condition_class"],
978+
monitor_class=self._resolve_uri_info["monitor_class"],
979+
condition_class=self._resolve_uri_info["condition_class"],
1002980
local_threshold_ms=self._options.local_threshold_ms,
1003981
server_selection_timeout=self._options.server_selection_timeout,
1004982
server_selector=self._options.server_selector,
1005983
heartbeat_frequency=self._options.heartbeat_frequency,
1006-
fqdn=self._for_resolve_uri["fqdn"],
984+
fqdn=self._resolve_uri_info["fqdn"],
1007985
direct_connection=self._options.direct_connection,
1008986
load_balanced=self._options.load_balanced,
1009987
srv_service_name=srv_service_name,
@@ -1013,6 +991,25 @@ def _resolve_uri(self):
1013991

1014992
self._topology = Topology(self._topology_settings)
1015993

994+
def _normalize_and_validate_opts(self, opts, seeds):
995+
# Handle security-option conflicts in combined options.
996+
opts = _handle_security_options(opts)
997+
# Normalize combined options.
998+
opts = _normalize_options(opts)
999+
_check_options(seeds, opts)
1000+
return opts
1001+
1002+
def _validate_kwargs_and_update_opts(self, keyword_opts, opts):
1003+
# Handle deprecated options in kwarg options.
1004+
keyword_opts = _handle_option_deprecations(keyword_opts)
1005+
# Validate kwarg options.
1006+
keyword_opts = common._CaseInsensitiveDictionary(
1007+
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
1008+
)
1009+
# Override connection string options with kwarg options.
1010+
opts.update(keyword_opts)
1011+
return opts
1012+
10161013
async def aconnect(self) -> None:
10171014
"""Explicitly connect to MongoDB asynchronously instead of on the first operation."""
10181015
await self._get_topology()

pymongo/synchronous/mongo_client.py

Lines changed: 51 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,7 @@ def __init__(
760760
# Parse options passed as kwargs.
761761
keyword_opts = common._CaseInsensitiveDictionary(kwargs)
762762
keyword_opts["document_class"] = doc_class
763+
self._resolve_uri_info = {"keyword_opts": keyword_opts}
763764

764765
seeds = set()
765766
username = None
@@ -812,25 +813,13 @@ def __init__(
812813
keyword_opts["tz_aware"] = tz_aware
813814
keyword_opts["connect"] = connect
814815

815-
# Handle deprecated options in kwarg options.
816-
keyword_opts = _handle_option_deprecations(keyword_opts)
817-
# Validate kwarg options.
818-
keyword_opts = common._CaseInsensitiveDictionary(
819-
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
820-
)
821-
822-
# Override connection string options with kwarg options.
823-
opts.update(keyword_opts)
816+
opts = self._validate_kwargs_and_update_opts(keyword_opts, opts)
824817

825818
if srv_service_name is None:
826819
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
827820

828821
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
829-
# Handle security-option conflicts in combined options.
830-
opts = _handle_security_options(opts)
831-
# Normalize combined options.
832-
opts = _normalize_options(opts)
833-
_check_options(seeds, opts)
822+
opts = self._normalize_and_validate_options(opts, seeds)
834823

835824
# Username and password passed as kwargs override user info in URI.
836825
self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
@@ -870,15 +859,17 @@ def __init__(
870859
self._closed = False
871860
self._init_background()
872861

873-
self._for_resolve_uri = {
874-
"username": username,
875-
"password": password,
876-
"dbase": dbase,
877-
"fqdn": fqdn,
878-
"pool_class": pool_class,
879-
"monitor_class": monitor_class,
880-
"condition_class": condition_class,
881-
}
862+
self._resolve_uri_info.update(
863+
{
864+
"username": username,
865+
"password": password,
866+
"dbase": dbase,
867+
"fqdn": fqdn,
868+
"pool_class": pool_class,
869+
"monitor_class": monitor_class,
870+
"condition_class": condition_class,
871+
}
872+
)
882873
if _IS_SYNC and connect:
883874
self._get_topology() # type: ignore[unused-coroutine]
884875

@@ -894,17 +885,16 @@ def __init__(
894885
# This will be used later if we fork.
895886
MongoClient._clients[self._topology._topology_id] = self
896887

888+
def _normalize_and_validate_options(self, opts, seeds):
889+
# Handle security-option conflicts in combined options.
890+
opts = _handle_security_options(opts)
891+
# Normalize combined options.
892+
opts = _normalize_options(opts)
893+
_check_options(seeds, opts)
894+
return opts
895+
897896
def _resolve_uri(self):
898-
keyword_opts = common._CaseInsensitiveDictionary(self._init_kwargs)
899-
for i in [
900-
"_pool_class",
901-
"_monitor_class",
902-
"_condition_class",
903-
"host",
904-
"port",
905-
"type_registry",
906-
]:
907-
keyword_opts.pop(i, None)
897+
keyword_opts = self._resolve_uri_info["keyword_opts"]
908898
seeds = set()
909899
opts = common._CaseInsensitiveDictionary()
910900
srv_service_name = keyword_opts.get("srvservicename")
@@ -955,31 +945,19 @@ def _resolve_uri(self):
955945
keyword_opts["tz_aware"] = tz_aware
956946
keyword_opts["connect"] = connect
957947

958-
# Handle deprecated options in kwarg options.
959-
keyword_opts = _handle_option_deprecations(keyword_opts)
960-
# Validate kwarg options.
961-
keyword_opts = common._CaseInsensitiveDictionary(
962-
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
963-
)
964-
965-
# Override connection string options with kwarg options.
966-
opts.update(keyword_opts)
948+
opts = self._validate_kwargs_and_update_opts(keyword_opts, opts)
967949

968950
if srv_service_name is None:
969951
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
970952

971953
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
972-
# Handle security-option conflicts in combined options.
973-
opts = _handle_security_options(opts)
974-
# Normalize combined options.
975-
opts = _normalize_options(opts)
976-
_check_options(seeds, opts)
954+
opts = self._normalize_and_validate_opts(opts, seeds)
977955

978956
# Username and password passed as kwargs override user info in URI.
979-
username = opts.get("username", self._for_resolve_uri["username"])
980-
password = opts.get("password", self._for_resolve_uri["password"])
957+
username = opts.get("username", self._resolve_uri_info["username"])
958+
password = opts.get("password", self._resolve_uri_info["password"])
981959
self._options = ClientOptions(
982-
username, password, self._for_resolve_uri["dbase"], opts, _IS_SYNC
960+
username, password, self._resolve_uri_info["dbase"], opts, _IS_SYNC
983961
)
984962

985963
self._event_listeners = self._options.pool_options._event_listeners
@@ -993,15 +971,15 @@ def _resolve_uri(self):
993971
self._topology_settings = TopologySettings(
994972
seeds=seeds,
995973
replica_set_name=self._options.replica_set_name,
996-
pool_class=self._for_resolve_uri["pool_class"],
974+
pool_class=self._resolve_uri_info["pool_class"],
997975
pool_options=self._options.pool_options,
998-
monitor_class=self._for_resolve_uri["monitor_class"],
999-
condition_class=self._for_resolve_uri["condition_class"],
976+
monitor_class=self._resolve_uri_info["monitor_class"],
977+
condition_class=self._resolve_uri_info["condition_class"],
1000978
local_threshold_ms=self._options.local_threshold_ms,
1001979
server_selection_timeout=self._options.server_selection_timeout,
1002980
server_selector=self._options.server_selector,
1003981
heartbeat_frequency=self._options.heartbeat_frequency,
1004-
fqdn=self._for_resolve_uri["fqdn"],
982+
fqdn=self._resolve_uri_info["fqdn"],
1005983
direct_connection=self._options.direct_connection,
1006984
load_balanced=self._options.load_balanced,
1007985
srv_service_name=srv_service_name,
@@ -1011,6 +989,25 @@ def _resolve_uri(self):
1011989

1012990
self._topology = Topology(self._topology_settings)
1013991

992+
def _normalize_and_validate_opts(self, opts, seeds):
993+
# Handle security-option conflicts in combined options.
994+
opts = _handle_security_options(opts)
995+
# Normalize combined options.
996+
opts = _normalize_options(opts)
997+
_check_options(seeds, opts)
998+
return opts
999+
1000+
def _validate_kwargs_and_update_opts(self, keyword_opts, opts):
1001+
# Handle deprecated options in kwarg options.
1002+
keyword_opts = _handle_option_deprecations(keyword_opts)
1003+
# Validate kwarg options.
1004+
keyword_opts = common._CaseInsensitiveDictionary(
1005+
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
1006+
)
1007+
# Override connection string options with kwarg options.
1008+
opts.update(keyword_opts)
1009+
return opts
1010+
10141011
def _connect(self) -> None:
10151012
"""Explicitly connect to MongoDB synchronously instead of on the first operation."""
10161013
self._get_topology()

0 commit comments

Comments
 (0)