Skip to content
18 changes: 13 additions & 5 deletions pymongo/asynchronous/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
_authenticate_oidc,
_get_authenticator,
)
from pymongo.asynchronous.helpers import _getaddrinfo
from pymongo.auth_shared import (
MongoCredential,
_authenticate_scram_start,
Expand Down Expand Up @@ -177,15 +178,22 @@ def _auth_key(nonce: str, username: str, password: str) -> str:
return md5hash.hexdigest()


def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
async def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
"""Canonicalize hostname following MIT-krb5 behavior."""
# https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
if option in [False, "none"]:
return hostname

af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
)[0]
af, socktype, proto, canonname, sockaddr = (
await _getaddrinfo(
hostname,
None,
family=0,
type=0,
proto=socket.IPPROTO_TCP,
flags=socket.AI_CANONNAME,
)
)[0] # type: ignore[index]

# For forward just to resolve the cname as dns.lookup() will not return it.
if option == "forward":
Expand Down Expand Up @@ -213,7 +221,7 @@ async def _authenticate_gssapi(credentials: MongoCredential, conn: AsyncConnecti
# Starting here and continuing through the while loop below - establish
# the security context. See RFC 4752, Section 3.1, first paragraph.
host = props.service_host or conn.address[0]
host = _canonicalize_hostname(host, props.canonicalize_host_name)
host = await _canonicalize_hostname(host, props.canonicalize_host_name)
service = props.service_name + "@" + host
if props.service_realm is not None:
service = service + "@" + props.service_realm
Expand Down
20 changes: 20 additions & 0 deletions pymongo/asynchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
"""Miscellaneous pieces that need to be synchronized."""
from __future__ import annotations

import asyncio
import builtins
import socket
import sys
from typing import (
Any,
Expand Down Expand Up @@ -68,6 +70,24 @@ async def inner(*args: Any, **kwargs: Any) -> Any:
return cast(F, inner)


async def _getaddrinfo(
host: Any, port: Any, **kwargs: Any
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if not _IS_SYNC:
loop = asyncio.get_running_loop()
return await loop.getaddrinfo(host, port, **kwargs) # type: ignore[return-value]
else:
return socket.getaddrinfo(host, port, **kwargs)


if sys.version_info >= (3, 10):
anext = builtins.anext
aiter = builtins.aiter
Expand Down
8 changes: 4 additions & 4 deletions pymongo/asynchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from bson import DEFAULT_CODEC_OPTIONS
from pymongo import _csot, helpers_shared
from pymongo.asynchronous.client_session import _validate_session_write_concern
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.asynchronous.helpers import _getaddrinfo, _handle_reauth
from pymongo.asynchronous.network import command, receive_message
from pymongo.common import (
MAX_BSON_SIZE,
Expand Down Expand Up @@ -783,7 +783,7 @@ def __repr__(self) -> str:
)


def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
async def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
"""Given (host, port) and PoolOptions, connect and return a socket object.

Can raise socket.error.
Expand Down Expand Up @@ -814,7 +814,7 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket
family = socket.AF_UNSPEC

err = None
for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined]
af, socktype, proto, dummy, sa = res
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
Expand Down Expand Up @@ -863,7 +863,7 @@ async def _configured_socket(

Sets socket's SSL and timeout options.
"""
sock = _create_connection(address, options)
sock = await _create_connection(address, options)
ssl_context = options._ssl_context

if ssl_context is None:
Expand Down
14 changes: 11 additions & 3 deletions pymongo/synchronous/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
_authenticate_oidc,
_get_authenticator,
)
from pymongo.synchronous.helpers import _getaddrinfo

if TYPE_CHECKING:
from pymongo.hello import Hello
Expand Down Expand Up @@ -180,9 +181,16 @@ def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
if option in [False, "none"]:
return hostname

af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
)[0]
af, socktype, proto, canonname, sockaddr = (
_getaddrinfo(
hostname,
None,
family=0,
type=0,
proto=socket.IPPROTO_TCP,
flags=socket.AI_CANONNAME,
)
)[0] # type: ignore[index]

# For forward just to resolve the cname as dns.lookup() will not return it.
if option == "forward":
Expand Down
20 changes: 20 additions & 0 deletions pymongo/synchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
"""Miscellaneous pieces that need to be synchronized."""
from __future__ import annotations

import asyncio
import builtins
import socket
import sys
from typing import (
Any,
Expand Down Expand Up @@ -68,6 +70,24 @@ def inner(*args: Any, **kwargs: Any) -> Any:
return cast(F, inner)


def _getaddrinfo(
host: Any, port: Any, **kwargs: Any
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if not _IS_SYNC:
loop = asyncio.get_running_loop()
return loop.getaddrinfo(host, port, **kwargs) # type: ignore[return-value]
else:
return socket.getaddrinfo(host, port, **kwargs)


if sys.version_info >= (3, 10):
next = builtins.next
iter = builtins.iter
Expand Down
4 changes: 2 additions & 2 deletions pymongo/synchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
from pymongo.socket_checker import SocketChecker
from pymongo.ssl_support import HAS_SNI, SSLError
from pymongo.synchronous.client_session import _validate_session_write_concern
from pymongo.synchronous.helpers import _handle_reauth
from pymongo.synchronous.helpers import _getaddrinfo, _handle_reauth
from pymongo.synchronous.network import command, receive_message

if TYPE_CHECKING:
Expand Down Expand Up @@ -812,7 +812,7 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket
family = socket.AF_UNSPEC

err = None
for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
for res in _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined]
af, socktype, proto, dummy, sa = res
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
Expand Down
4 changes: 2 additions & 2 deletions test/asynchronous/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,10 @@ async def test_gssapi_threaded(self):
async def test_gssapi_canonicalize_host_name(self):
# Test the low level method.
assert GSSAPI_HOST is not None
result = _canonicalize_hostname(GSSAPI_HOST, "forward")
result = await _canonicalize_hostname(GSSAPI_HOST, "forward")
if "compute-1.amazonaws.com" not in result:
self.assertEqual(result, GSSAPI_HOST)
result = _canonicalize_hostname(GSSAPI_HOST, "forwardAndReverse")
result = await _canonicalize_hostname(GSSAPI_HOST, "forwardAndReverse")
self.assertEqual(result, GSSAPI_HOST)

# Use the equivalent named CANONICALIZE_HOST_NAME.
Expand Down
Loading