Skip to content
4 changes: 4 additions & 0 deletions pymongo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Python driver for MongoDB."""
from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor
from typing import ContextManager, Optional

__all__ = [
Expand Down Expand Up @@ -166,3 +167,6 @@ def timeout(seconds: Optional[float]) -> ContextManager[None]:
if seconds is not None:
seconds = float(seconds)
return _csot._TimeoutContext(seconds)


_PYMONGO_EXECUTOR = ThreadPoolExecutor(thread_name_prefix="PYMONGO_EXECUTOR-")
22 changes: 22 additions & 0 deletions pymongo/_asyncio_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A separate ThreadPoolExecutor instance used internally to avoid competing for resources with the default asyncio ThreadPoolExecutor
that user code will use."""

from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor

_PYMONGO_EXECUTOR = ThreadPoolExecutor(thread_name_prefix="PYMONGO_EXECUTOR-")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I like this approach because we now have a thread pool that hangs around forever even after all clients have been closed.

My other comment was more around adding guidance for potential errors, not for changing our implementation. Like something that says "if your app runs into "XXX" error consider this may mean your app's default loop executor is under provisioned. Consider increasing the size of this thread pool or ..."

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be difficult to distinguish when this issue occurs, having a separate thread pool for our internal use will help mitigate how common it is. An extra thread pool instance shouldn't be expensive to have a reference to for the lifetime of the application.

Copy link
Member

@ShaneHarvey ShaneHarvey Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally I prefer we go with the loop.getaddrinfo approach because it avoids adding the complexity of managing our own thread pool. It's not really kosher to leave a threadpool open even if the threads are "idle". The limitation in loop.getaddrinfo is also implementation detail that could be fixed at any point (even in a python bugfix release).

I expect it will be clear when this issue occurs because a timeout error caused by threadpool starvation looks different than a real DNS timeout error. It should be simple to add an example to our docs by:

  1. saturating the executor with long running tasks
  2. then attempting run a client command
  3. record the error

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see much complexity in managing our own thread pool, but I totally understand the desire to not have an extra pool lying around. I'll revert back to using loop.getaddrinfo() once I have a good example for our docs.

Copy link
Contributor Author

@NoahStapp NoahStapp Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After investigating, I believe the docs are slightly misleading: what actually happens when the executor pool is fully saturated is any loop.getaddrinfo() calls block until a thread is freed up for use. There's no timeout mechanism inherent to the executor pool. We could add our own timeout to every loop.run_in_executor() call to prevent users from accidentally blocking the driver forever if they saturate the default executor permanently, but then we would cause timeouts to occur whenever the response is too slow.

If we don't add any timeouts to those calls, users will experience slowdowns whenever they perform a driver operation while the default executor pool is fully saturated. That's preferable to spurious timeouts in my opinion, especially when the user's own code is what determines the frequency of the timeouts.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for investigating, I agree with that. The cpython ticket referencing anyio so that could explain the difference.

18 changes: 13 additions & 5 deletions pymongo/asynchronous/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
_authenticate_oidc,
_get_authenticator,
)
from pymongo.asynchronous.helpers import getaddrinfo
from pymongo.auth_shared import (
MongoCredential,
_authenticate_scram_start,
Expand Down Expand Up @@ -177,15 +178,22 @@ def _auth_key(nonce: str, username: str, password: str) -> str:
return md5hash.hexdigest()


def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
async def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
"""Canonicalize hostname following MIT-krb5 behavior."""
# https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
if option in [False, "none"]:
return hostname

af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
)[0]
af, socktype, proto, canonname, sockaddr = (
await getaddrinfo(
hostname,
None,
family=0,
type=0,
proto=socket.IPPROTO_TCP,
flags=socket.AI_CANONNAME,
)
)[0] # type: ignore[index]

# For forward just to resolve the cname as dns.lookup() will not return it.
if option == "forward":
Expand Down Expand Up @@ -213,7 +221,7 @@ async def _authenticate_gssapi(credentials: MongoCredential, conn: AsyncConnecti
# Starting here and continuing through the while loop below - establish
# the security context. See RFC 4752, Section 3.1, first paragraph.
host = props.service_host or conn.address[0]
host = _canonicalize_hostname(host, props.canonicalize_host_name)
host = await _canonicalize_hostname(host, props.canonicalize_host_name)
service = props.service_name + "@" + host
if props.service_realm is not None:
service = service + "@" + props.service_realm
Expand Down
12 changes: 12 additions & 0 deletions pymongo/asynchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
"""Miscellaneous pieces that need to be synchronized."""
from __future__ import annotations

import asyncio
import builtins
import socket
import sys
from typing import (
Any,
Expand Down Expand Up @@ -68,6 +70,16 @@ async def inner(*args: Any, **kwargs: Any) -> Any:
return cast(F, inner)


async def getaddrinfo(host, port, **kwargs):
if not _IS_SYNC:
loop = asyncio.get_running_loop()
return await loop.getaddrinfo( # type: ignore[assignment]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we be using run_in_executor here instead as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch sorry, juggling too many changes at once 😅

host, port, **kwargs
)
else:
return socket.getaddrinfo(host, port, **kwargs) # type: ignore[assignment]


if sys.version_info >= (3, 10):
anext = builtins.anext
aiter = builtins.aiter
Expand Down
15 changes: 9 additions & 6 deletions pymongo/asynchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@

from bson import DEFAULT_CODEC_OPTIONS
from pymongo import _csot, helpers_shared
from pymongo._asyncio_executor import _PYMONGO_EXECUTOR
from pymongo.asynchronous.client_session import _validate_session_write_concern
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.asynchronous.helpers import _handle_reauth, getaddrinfo
from pymongo.asynchronous.network import command, receive_message
from pymongo.common import (
MAX_BSON_SIZE,
Expand Down Expand Up @@ -783,7 +784,7 @@ def __repr__(self) -> str:
)


def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
async def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
"""Given (host, port) and PoolOptions, connect and return a socket object.

Can raise socket.error.
Expand Down Expand Up @@ -814,7 +815,7 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket
family = socket.AF_UNSPEC

err = None
for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
for res in await getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined]
af, socktype, proto, dummy, sa = res
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
Expand Down Expand Up @@ -863,7 +864,7 @@ async def _configured_socket(

Sets socket's SSL and timeout options.
"""
sock = _create_connection(address, options)
sock = await _create_connection(address, options)
ssl_context = options._ssl_context

if ssl_context is None:
Expand All @@ -883,7 +884,7 @@ async def _configured_socket(
else:
loop = asyncio.get_running_loop()
ssl_sock = await loop.run_in_executor(
None,
_PYMONGO_EXECUTOR,
functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc]
)
else:
Expand All @@ -894,7 +895,9 @@ async def _configured_socket(
ssl_sock = await ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc]
else:
loop = asyncio.get_running_loop()
ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc]
ssl_sock = await loop.run_in_executor(
_PYMONGO_EXECUTOR, ssl_context.wrap_socket, sock
) # type: ignore[assignment, misc]
except _CertificateError:
sock.close()
# Raise _CertificateError directly like we do after match_hostname
Expand Down
5 changes: 3 additions & 2 deletions pymongo/pyopenssl_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from OpenSSL import SSL as _SSL
from OpenSSL import crypto as _crypto

from pymongo._asyncio_executor import _PYMONGO_EXECUTOR
from pymongo.errors import ConfigurationError as _ConfigurationError
from pymongo.errors import _CertificateError # type:ignore[attr-defined]
from pymongo.ocsp_cache import _OCSPCache
Expand Down Expand Up @@ -405,15 +406,15 @@ async def a_wrap_socket(
ssl_conn.set_tlsext_host_name(server_hostname.encode("idna"))
if self.verify_mode != _stdlibssl.CERT_NONE:
# Request a stapled OCSP response.
await loop.run_in_executor(None, ssl_conn.request_ocsp)
await loop.run_in_executor(_PYMONGO_EXECUTOR, ssl_conn.request_ocsp)
ssl_conn.set_connect_state()
# If this wasn't true the caller of wrap_socket would call
# do_handshake()
if do_handshake_on_connect:
# XXX: If we do hostname checking in a callback we can get rid
# of this call to do_handshake() since the handshake
# will happen automatically later.
await loop.run_in_executor(None, ssl_conn.do_handshake)
await loop.run_in_executor(_PYMONGO_EXECUTOR, ssl_conn.do_handshake)
# XXX: Do this in a callback registered with
# SSLContext.set_info_callback? See Twisted for an example.
if self.check_hostname and server_hostname is not None:
Expand Down
14 changes: 11 additions & 3 deletions pymongo/synchronous/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
_authenticate_oidc,
_get_authenticator,
)
from pymongo.synchronous.helpers import getaddrinfo

if TYPE_CHECKING:
from pymongo.hello import Hello
Expand Down Expand Up @@ -180,9 +181,16 @@ def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
if option in [False, "none"]:
return hostname

af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
)[0]
af, socktype, proto, canonname, sockaddr = (
getaddrinfo(
hostname,
None,
family=0,
type=0,
proto=socket.IPPROTO_TCP,
flags=socket.AI_CANONNAME,
)
)[0] # type: ignore[index]

# For forward just to resolve the cname as dns.lookup() will not return it.
if option == "forward":
Expand Down
12 changes: 12 additions & 0 deletions pymongo/synchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
"""Miscellaneous pieces that need to be synchronized."""
from __future__ import annotations

import asyncio
import builtins
import socket
import sys
from typing import (
Any,
Expand Down Expand Up @@ -68,6 +70,16 @@ def inner(*args: Any, **kwargs: Any) -> Any:
return cast(F, inner)


def getaddrinfo(host, port, **kwargs):
if not _IS_SYNC:
loop = asyncio.get_running_loop()
return loop.getaddrinfo( # type: ignore[assignment]
host, port, **kwargs
)
else:
return socket.getaddrinfo(host, port, **kwargs) # type: ignore[assignment]


if sys.version_info >= (3, 10):
next = builtins.next
iter = builtins.iter
Expand Down
11 changes: 7 additions & 4 deletions pymongo/synchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

from bson import DEFAULT_CODEC_OPTIONS
from pymongo import _csot, helpers_shared
from pymongo._asyncio_executor import _PYMONGO_EXECUTOR
from pymongo.common import (
MAX_BSON_SIZE,
MAX_MESSAGE_SIZE,
Expand Down Expand Up @@ -84,7 +85,7 @@
from pymongo.socket_checker import SocketChecker
from pymongo.ssl_support import HAS_SNI, SSLError
from pymongo.synchronous.client_session import _validate_session_write_concern
from pymongo.synchronous.helpers import _handle_reauth
from pymongo.synchronous.helpers import _handle_reauth, getaddrinfo
from pymongo.synchronous.network import command, receive_message

if TYPE_CHECKING:
Expand Down Expand Up @@ -812,7 +813,7 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket
family = socket.AF_UNSPEC

err = None
for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
for res in getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined]
af, socktype, proto, dummy, sa = res
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
Expand Down Expand Up @@ -879,7 +880,7 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.
else:
loop = asyncio.get_running_loop()
ssl_sock = loop.run_in_executor(
None,
_PYMONGO_EXECUTOR,
functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc]
)
else:
Expand All @@ -890,7 +891,9 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.
ssl_sock = ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc]
else:
loop = asyncio.get_running_loop()
ssl_sock = loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc]
ssl_sock = loop.run_in_executor(
_PYMONGO_EXECUTOR, ssl_context.wrap_socket, sock
) # type: ignore[assignment, misc]
except _CertificateError:
sock.close()
# Raise _CertificateError directly like we do after match_hostname
Expand Down
4 changes: 2 additions & 2 deletions test/asynchronous/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,10 @@ async def test_gssapi_threaded(self):
async def test_gssapi_canonicalize_host_name(self):
# Test the low level method.
assert GSSAPI_HOST is not None
result = _canonicalize_hostname(GSSAPI_HOST, "forward")
result = await _canonicalize_hostname(GSSAPI_HOST, "forward")
if "compute-1.amazonaws.com" not in result:
self.assertEqual(result, GSSAPI_HOST)
result = _canonicalize_hostname(GSSAPI_HOST, "forwardAndReverse")
result = await _canonicalize_hostname(GSSAPI_HOST, "forwardAndReverse")
self.assertEqual(result, GSSAPI_HOST)

# Use the equivalent named CANONICALIZE_HOST_NAME.
Expand Down
Loading