Skip to content
24 changes: 19 additions & 5 deletions pymongo/asynchronous/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Authentication helpers."""
from __future__ import annotations

import asyncio
import functools
import hashlib
import hmac
Expand Down Expand Up @@ -177,15 +178,28 @@ def _auth_key(nonce: str, username: str, password: str) -> str:
return md5hash.hexdigest()


def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
async def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
"""Canonicalize hostname following MIT-krb5 behavior."""
# https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
if option in [False, "none"]:
return hostname

af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
)[0]
if not _IS_SYNC:
loop = asyncio.get_event_loop()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be get_running_loop?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, "Because this function has rather complex behavior (especially when custom event loop policies are in use), using the get_running_loop() function is preferred to get_event_loop() in coroutines and callbacks.".

https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.get_event_loop

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good find--we use asyncio.get_event_loop() elsewhere in the code, I'll open a separate ticket to change those uses as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed for the rest of the codebase in #2063.

af, socktype, proto, canonname, sockaddr = (
await loop.getaddrinfo(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI:

Note Both getaddrinfo and getnameinfo internally utilize their synchronous versions through the loop’s default thread pool executor. When this executor is saturated, these methods may experience delays, which higher-level networking libraries may report as increased timeouts. To mitigate this, consider using a custom executor for other user tasks, or setting a default executor with a larger number of workers.

https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.getaddrinfo

Which means our users will eventually run into this issue: python/cpython#112169

This is still better than blocking the loop of course but I wonder if we need to warn of this potential problem or if we should test it explicitly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would explicitly testing it provide an actionable solution? We could increase the default number of workers to help mitigate this, but warning users might only confuse them since this is an internal API.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we instead use run_in_executor and have our own executor? We use run_in_executor in _configured_socket as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our current uses of run_in_executor also utilize the default thread pool executor. We could configure the executor to have a higher default number of workers, but we'd still hit the same issue depending on the system's resource limits.

Copy link
Member

@blink1073 blink1073 Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point is that user's code will default to the default executor, so we'd be contending with its resources. We'd be essentially taking Guido's advice and applying it to a library so it doesn't interfere with a default user.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, you're saying we use our own ThreadPoolExecutor instance to avoid competing with the default executor? Does that make a difference when the underlying OS threads are still shared between the executors?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does if the OS thread limit is much higher than the default thread executor's thread limit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I like that idea then.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also suggest making a utility function for getaddrinfo to avoid repeating this ugly block everywhere. ;)

hostname,
None,
family=0,
type=0,
proto=socket.IPPROTO_TCP,
flags=socket.AI_CANONNAME,
)
)[0] # type: ignore[index]
else:
af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
)[0]

# For forward just to resolve the cname as dns.lookup() will not return it.
if option == "forward":
Expand Down Expand Up @@ -213,7 +227,7 @@ async def _authenticate_gssapi(credentials: MongoCredential, conn: AsyncConnecti
# Starting here and continuing through the while loop below - establish
# the security context. See RFC 4752, Section 3.1, first paragraph.
host = props.service_host or conn.address[0]
host = _canonicalize_hostname(host, props.canonicalize_host_name)
host = await _canonicalize_hostname(host, props.canonicalize_host_name)
service = props.service_name + "@" + host
if props.service_realm is not None:
service = service + "@" + props.service_realm
Expand Down
13 changes: 10 additions & 3 deletions pymongo/asynchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ def __repr__(self) -> str:
)


def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
async def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
"""Given (host, port) and PoolOptions, connect and return a socket object.

Can raise socket.error.
Expand Down Expand Up @@ -814,7 +814,14 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket
family = socket.AF_UNSPEC

err = None
for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
if not _IS_SYNC:
loop = asyncio.get_event_loop()
results = await loop.getaddrinfo( # type: ignore[assignment]
host, port, family=family, type=socket.SOCK_STREAM
)
else:
results = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM) # type: ignore[assignment]
for res in results: # type: ignore[attr-defined]
af, socktype, proto, dummy, sa = res
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
Expand Down Expand Up @@ -863,7 +870,7 @@ async def _configured_socket(

Sets socket's SSL and timeout options.
"""
sock = _create_connection(address, options)
sock = await _create_connection(address, options)
ssl_context = options._ssl_context

if ssl_context is None:
Expand Down
20 changes: 17 additions & 3 deletions pymongo/synchronous/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Authentication helpers."""
from __future__ import annotations

import asyncio
import functools
import hashlib
import hmac
Expand Down Expand Up @@ -180,9 +181,22 @@ def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
if option in [False, "none"]:
return hostname

af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
)[0]
if not _IS_SYNC:
loop = asyncio.get_event_loop()
af, socktype, proto, canonname, sockaddr = (
loop.getaddrinfo(
hostname,
None,
family=0,
type=0,
proto=socket.IPPROTO_TCP,
flags=socket.AI_CANONNAME,
)
)[0] # type: ignore[index]
else:
af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
)[0]

# For forward just to resolve the cname as dns.lookup() will not return it.
if option == "forward":
Expand Down
9 changes: 8 additions & 1 deletion pymongo/synchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,14 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket
family = socket.AF_UNSPEC

err = None
for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
if not _IS_SYNC:
loop = asyncio.get_event_loop()
results = loop.getaddrinfo( # type: ignore[assignment]
host, port, family=family, type=socket.SOCK_STREAM
)
else:
results = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM) # type: ignore[assignment]
for res in results: # type: ignore[attr-defined]
af, socktype, proto, dummy, sa = res
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
Expand Down
4 changes: 2 additions & 2 deletions test/asynchronous/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,10 @@ async def test_gssapi_threaded(self):
async def test_gssapi_canonicalize_host_name(self):
# Test the low level method.
assert GSSAPI_HOST is not None
result = _canonicalize_hostname(GSSAPI_HOST, "forward")
result = await _canonicalize_hostname(GSSAPI_HOST, "forward")
if "compute-1.amazonaws.com" not in result:
self.assertEqual(result, GSSAPI_HOST)
result = _canonicalize_hostname(GSSAPI_HOST, "forwardAndReverse")
result = await _canonicalize_hostname(GSSAPI_HOST, "forwardAndReverse")
self.assertEqual(result, GSSAPI_HOST)

# Use the equivalent named CANONICALIZE_HOST_NAME.
Expand Down
Loading