diff --git a/pymongo/asynchronous/auth.py b/pymongo/asynchronous/auth.py index 48ce4bbd39..b1e6d0125b 100644 --- a/pymongo/asynchronous/auth.py +++ b/pymongo/asynchronous/auth.py @@ -38,6 +38,7 @@ _authenticate_oidc, _get_authenticator, ) +from pymongo.asynchronous.helpers import _getaddrinfo from pymongo.auth_shared import ( MongoCredential, _authenticate_scram_start, @@ -177,15 +178,22 @@ def _auth_key(nonce: str, username: str, password: str) -> str: return md5hash.hexdigest() -def _canonicalize_hostname(hostname: str, option: str | bool) -> str: +async def _canonicalize_hostname(hostname: str, option: str | bool) -> str: """Canonicalize hostname following MIT-krb5 behavior.""" # https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520 if option in [False, "none"]: return hostname - af, socktype, proto, canonname, sockaddr = socket.getaddrinfo( - hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME - )[0] + af, socktype, proto, canonname, sockaddr = ( + await _getaddrinfo( + hostname, + None, + family=0, + type=0, + proto=socket.IPPROTO_TCP, + flags=socket.AI_CANONNAME, + ) + )[0] # type: ignore[index] # For forward just to resolve the cname as dns.lookup() will not return it. if option == "forward": @@ -213,7 +221,7 @@ async def _authenticate_gssapi(credentials: MongoCredential, conn: AsyncConnecti # Starting here and continuing through the while loop below - establish # the security context. See RFC 4752, Section 3.1, first paragraph. host = props.service_host or conn.address[0] - host = _canonicalize_hostname(host, props.canonicalize_host_name) + host = await _canonicalize_hostname(host, props.canonicalize_host_name) service = props.service_name + "@" + host if props.service_realm is not None: service = service + "@" + props.service_realm diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index 1ac8b6630f..d519e8749c 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -15,7 +15,9 @@ """Miscellaneous pieces that need to be synchronized.""" from __future__ import annotations +import asyncio import builtins +import socket import sys from typing import ( Any, @@ -68,6 +70,24 @@ async def inner(*args: Any, **kwargs: Any) -> Any: return cast(F, inner) +async def _getaddrinfo( + host: Any, port: Any, **kwargs: Any +) -> list[ + tuple[ + socket.AddressFamily, + socket.SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] +]: + if not _IS_SYNC: + loop = asyncio.get_running_loop() + return await loop.getaddrinfo(host, port, **kwargs) # type: ignore[return-value] + else: + return socket.getaddrinfo(host, port, **kwargs) + + if sys.version_info >= (3, 10): anext = builtins.anext aiter = builtins.aiter diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 5dc5675a0a..bf2f2b4946 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -40,7 +40,7 @@ from bson import DEFAULT_CODEC_OPTIONS from pymongo import _csot, helpers_shared from pymongo.asynchronous.client_session import _validate_session_write_concern -from pymongo.asynchronous.helpers import _handle_reauth +from pymongo.asynchronous.helpers import _getaddrinfo, _handle_reauth from pymongo.asynchronous.network import command, receive_message from pymongo.common import ( MAX_BSON_SIZE, @@ -783,7 +783,7 @@ def __repr__(self) -> str: ) -def _create_connection(address: _Address, options: PoolOptions) -> socket.socket: +async def _create_connection(address: _Address, options: PoolOptions) -> socket.socket: """Given (host, port) and PoolOptions, connect and return a socket object. Can raise socket.error. @@ -814,7 +814,7 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket family = socket.AF_UNSPEC err = None - for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM): + for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined] af, socktype, proto, dummy, sa = res # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 @@ -863,7 +863,7 @@ async def _configured_socket( Sets socket's SSL and timeout options. """ - sock = _create_connection(address, options) + sock = await _create_connection(address, options) ssl_context = options._ssl_context if ssl_context is None: diff --git a/pymongo/synchronous/auth.py b/pymongo/synchronous/auth.py index 0e51ff8b7f..56860eff3b 100644 --- a/pymongo/synchronous/auth.py +++ b/pymongo/synchronous/auth.py @@ -45,6 +45,7 @@ _authenticate_oidc, _get_authenticator, ) +from pymongo.synchronous.helpers import _getaddrinfo if TYPE_CHECKING: from pymongo.hello import Hello @@ -180,9 +181,16 @@ def _canonicalize_hostname(hostname: str, option: str | bool) -> str: if option in [False, "none"]: return hostname - af, socktype, proto, canonname, sockaddr = socket.getaddrinfo( - hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME - )[0] + af, socktype, proto, canonname, sockaddr = ( + _getaddrinfo( + hostname, + None, + family=0, + type=0, + proto=socket.IPPROTO_TCP, + flags=socket.AI_CANONNAME, + ) + )[0] # type: ignore[index] # For forward just to resolve the cname as dns.lookup() will not return it. if option == "forward": diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index 064583dad3..f800e7dcc8 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -15,7 +15,9 @@ """Miscellaneous pieces that need to be synchronized.""" from __future__ import annotations +import asyncio import builtins +import socket import sys from typing import ( Any, @@ -68,6 +70,24 @@ def inner(*args: Any, **kwargs: Any) -> Any: return cast(F, inner) +def _getaddrinfo( + host: Any, port: Any, **kwargs: Any +) -> list[ + tuple[ + socket.AddressFamily, + socket.SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] +]: + if not _IS_SYNC: + loop = asyncio.get_running_loop() + return loop.getaddrinfo(host, port, **kwargs) # type: ignore[return-value] + else: + return socket.getaddrinfo(host, port, **kwargs) + + if sys.version_info >= (3, 10): next = builtins.next iter = builtins.iter diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 1a155c82d7..05f930d480 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -84,7 +84,7 @@ from pymongo.socket_checker import SocketChecker from pymongo.ssl_support import HAS_SNI, SSLError from pymongo.synchronous.client_session import _validate_session_write_concern -from pymongo.synchronous.helpers import _handle_reauth +from pymongo.synchronous.helpers import _getaddrinfo, _handle_reauth from pymongo.synchronous.network import command, receive_message if TYPE_CHECKING: @@ -812,7 +812,7 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket family = socket.AF_UNSPEC err = None - for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM): + for res in _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined] af, socktype, proto, dummy, sa = res # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 diff --git a/test/asynchronous/test_auth.py b/test/asynchronous/test_auth.py index 08dc4d7247..7172152d69 100644 --- a/test/asynchronous/test_auth.py +++ b/test/asynchronous/test_auth.py @@ -275,10 +275,10 @@ async def test_gssapi_threaded(self): async def test_gssapi_canonicalize_host_name(self): # Test the low level method. assert GSSAPI_HOST is not None - result = _canonicalize_hostname(GSSAPI_HOST, "forward") + result = await _canonicalize_hostname(GSSAPI_HOST, "forward") if "compute-1.amazonaws.com" not in result: self.assertEqual(result, GSSAPI_HOST) - result = _canonicalize_hostname(GSSAPI_HOST, "forwardAndReverse") + result = await _canonicalize_hostname(GSSAPI_HOST, "forwardAndReverse") self.assertEqual(result, GSSAPI_HOST) # Use the equivalent named CANONICALIZE_HOST_NAME.