diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 68de42db84..71a694a619 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -64,11 +64,6 @@ from pymongo.asynchronous.cursor import AsyncCursor from pymongo.asynchronous.database import AsyncDatabase from pymongo.asynchronous.mongo_client import AsyncMongoClient -from pymongo.asynchronous.pool import ( - _configured_socket, - _get_timeout_details, - _raise_connection_failure, -) from pymongo.common import CONNECT_TIMEOUT from pymongo.daemon import _spawn_daemon from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts @@ -80,12 +75,17 @@ NetworkTimeout, ServerSelectionTimeoutError, ) -from pymongo.network_layer import BLOCKING_IO_ERRORS, async_sendall +from pymongo.network_layer import async_socket_sendall from pymongo.operations import UpdateOne from pymongo.pool_options import PoolOptions +from pymongo.pool_shared import ( + _async_configured_socket, + _get_timeout_details, + _raise_connection_failure, +) from pymongo.read_concern import ReadConcern from pymongo.results import BulkWriteResult, DeleteResult -from pymongo.ssl_support import get_ssl_context +from pymongo.ssl_support import BLOCKING_IO_ERRORS, get_ssl_context from pymongo.typings import _DocumentType, _DocumentTypeArg from pymongo.uri_parser_shared import parse_host from pymongo.write_concern import WriteConcern @@ -113,7 +113,7 @@ async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]: try: - return await _configured_socket(address, opts) + return await _async_configured_socket(address, opts) except Exception as exc: _raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts)) @@ -196,7 +196,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: try: conn = await _connect_kms(address, opts) try: - await async_sendall(conn, message) + await async_socket_sendall(conn, message) while kms_context.bytes_needed > 0: # CSOT: update timeout. conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 754b8325ed..5c763c2894 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2078,7 +2078,7 @@ async def _cleanup_cursor_lock( # exhausted the result set we *must* close the socket # to stop the server from sending more data. assert conn_mgr.conn is not None - conn_mgr.conn.close_conn(ConnectionClosedReason.ERROR) + await conn_mgr.conn.close_conn(ConnectionClosedReason.ERROR) else: await self._close_cursor_now(cursor_id, address, session=session, conn_mgr=conn_mgr) if conn_mgr: diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index 1b0799e1c4..479ca1a314 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -36,7 +36,11 @@ from pymongo.server_description import ServerDescription if TYPE_CHECKING: - from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext + from pymongo.asynchronous.pool import ( # type: ignore[attr-defined] + AsyncConnection, + Pool, + _CancellationContext, + ) from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py index e529a52ee9..5f14bef45d 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/network.py @@ -17,7 +17,6 @@ import datetime import logging -import time from typing import ( TYPE_CHECKING, Any, @@ -31,20 +30,16 @@ from bson import _decode_all_selective from pymongo import _csot, helpers_shared, message -from pymongo.common import MAX_MESSAGE_SIZE -from pymongo.compression_support import _NO_COMPRESSION, decompress +from pymongo.compression_support import _NO_COMPRESSION from pymongo.errors import ( NotPrimaryError, OperationFailure, - ProtocolError, ) from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log -from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply +from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import ( - _UNPACK_COMPRESSION_HEADER, - _UNPACK_HEADER, - async_receive_data, + async_receive_message, async_sendall, ) @@ -194,13 +189,13 @@ async def command( ) try: - await async_sendall(conn.conn, msg) + await async_sendall(conn.conn.get_conn, msg) if use_op_msg and unacknowledged: # Unacknowledged, fake a successful command response. reply = None response_doc: _DocumentOut = {"ok": 1} else: - reply = await receive_message(conn, request_id) + reply = await async_receive_message(conn, request_id) conn.more_to_come = reply.more_to_come unpacked_docs = reply.unpack_response( codec_options=codec_options, user_fields=user_fields @@ -301,47 +296,3 @@ async def command( ) return response_doc # type: ignore[return-value] - - -async def receive_message( - conn: AsyncConnection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE -) -> Union[_OpReply, _OpMsg]: - """Receive a raw BSON message or raise socket.error.""" - if _csot.get_timeout(): - deadline = _csot.get_deadline() - else: - timeout = conn.conn.gettimeout() - if timeout: - deadline = time.monotonic() + timeout - else: - deadline = None - # Ignore the response's request id. - length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data(conn, 16, deadline)) - # No request_id for exhaust cursor "getMore". - if request_id is not None: - if request_id != response_to: - raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}") - if length <= 16: - raise ProtocolError( - f"Message length ({length!r}) not longer than standard message header size (16)" - ) - if length > max_message_size: - raise ProtocolError( - f"Message length ({length!r}) is larger than server max " - f"message size ({max_message_size!r})" - ) - if op_code == 2012: - op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( - await async_receive_data(conn, 9, deadline) - ) - data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id) - else: - data = await async_receive_data(conn, length - 16, deadline) - - try: - unpack_reply = _UNPACK_REPLY[op_code] - except KeyError: - raise ProtocolError( - f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" - ) from None - return unpack_reply(data) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index d06c528e78..6ebdb5cb20 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -14,14 +14,10 @@ from __future__ import annotations -import asyncio import collections import contextlib -import functools import logging import os -import socket -import ssl import sys import time import weakref @@ -40,8 +36,8 @@ from bson import DEFAULT_CODEC_OPTIONS from pymongo import _csot, helpers_shared from pymongo.asynchronous.client_session import _validate_session_write_concern -from pymongo.asynchronous.helpers import _getaddrinfo, _handle_reauth -from pymongo.asynchronous.network import command, receive_message +from pymongo.asynchronous.helpers import _handle_reauth +from pymongo.asynchronous.network import command from pymongo.common import ( MAX_BSON_SIZE, MAX_MESSAGE_SIZE, @@ -52,16 +48,13 @@ from pymongo.errors import ( # type:ignore[attr-defined] AutoReconnect, ConfigurationError, - ConnectionFailure, DocumentTooLarge, ExecutionTimeout, InvalidOperation, - NetworkTimeout, NotPrimaryError, OperationFailure, PyMongoError, WaitQueueTimeoutError, - _CertificateError, ) from pymongo.hello import Hello, HelloCompat from pymongo.lock import ( @@ -79,13 +72,20 @@ ConnectionCheckOutFailedReason, ConnectionClosedReason, ) -from pymongo.network_layer import async_sendall +from pymongo.network_layer import AsyncNetworkingInterface, async_receive_message, async_sendall from pymongo.pool_options import PoolOptions +from pymongo.pool_shared import ( + _CancellationContext, + _configured_protocol_interface, + _get_timeout_details, + _raise_connection_failure, + format_timeout_details, +) from pymongo.read_preferences import ReadPreference from pymongo.server_api import _add_to_command from pymongo.server_type import SERVER_TYPE from pymongo.socket_checker import SocketChecker -from pymongo.ssl_support import HAS_SNI, SSLError +from pymongo.ssl_support import SSLError if TYPE_CHECKING: from bson import CodecOptions @@ -99,7 +99,6 @@ ZstdContext, ) from pymongo.message import _OpMsg, _OpReply - from pymongo.pyopenssl_context import _sslConn from pymongo.read_concern import ReadConcern from pymongo.read_preferences import _ServerMode from pymongo.typings import _Address, _CollationIn @@ -123,133 +122,6 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001 _IS_SYNC = False -_MAX_TCP_KEEPIDLE = 120 -_MAX_TCP_KEEPINTVL = 10 -_MAX_TCP_KEEPCNT = 9 - -if sys.platform == "win32": - try: - import _winreg as winreg - except ImportError: - import winreg - - def _query(key, name, default): - try: - value, _ = winreg.QueryValueEx(key, name) - # Ensure the value is a number or raise ValueError. - return int(value) - except (OSError, ValueError): - # QueryValueEx raises OSError when the key does not exist (i.e. - # the system is using the Windows default value). - return default - - try: - with winreg.OpenKey( - winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters" - ) as key: - _WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000) - _WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000) - except OSError: - # We could not check the default values because winreg.OpenKey failed. - # Assume the system is using the default values. - _WINDOWS_TCP_IDLE_MS = 7200000 - _WINDOWS_TCP_INTERVAL_MS = 1000 - - def _set_keepalive_times(sock): - idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000) - interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000) - if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS: - sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms)) - -else: - - def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None: - if hasattr(socket, tcp_option): - sockopt = getattr(socket, tcp_option) - try: - # PYTHON-1350 - NetBSD doesn't implement getsockopt for - # TCP_KEEPIDLE and friends. Don't attempt to set the - # values there. - default = sock.getsockopt(socket.IPPROTO_TCP, sockopt) - if default > max_value: - sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value) - except OSError: - pass - - def _set_keepalive_times(sock: socket.socket) -> None: - _set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE) - _set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL) - _set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT) - - -def _raise_connection_failure( - address: Any, - error: Exception, - msg_prefix: Optional[str] = None, - timeout_details: Optional[dict[str, float]] = None, -) -> NoReturn: - """Convert a socket.error to ConnectionFailure and raise it.""" - host, port = address - # If connecting to a Unix socket, port will be None. - if port is not None: - msg = "%s:%d: %s" % (host, port, error) - else: - msg = f"{host}: {error}" - if msg_prefix: - msg = msg_prefix + msg - if "configured timeouts" not in msg: - msg += format_timeout_details(timeout_details) - if isinstance(error, socket.timeout): - raise NetworkTimeout(msg) from error - elif isinstance(error, SSLError) and "timed out" in str(error): - # Eventlet does not distinguish TLS network timeouts from other - # SSLErrors (https://github.com/eventlet/eventlet/issues/692). - # Luckily, we can work around this limitation because the phrase - # 'timed out' appears in all the timeout related SSLErrors raised. - raise NetworkTimeout(msg) from error - else: - raise AutoReconnect(msg) from error - - -def _get_timeout_details(options: PoolOptions) -> dict[str, float]: - details = {} - timeout = _csot.get_timeout() - socket_timeout = options.socket_timeout - connect_timeout = options.connect_timeout - if timeout: - details["timeoutMS"] = timeout * 1000 - if socket_timeout and not timeout: - details["socketTimeoutMS"] = socket_timeout * 1000 - if connect_timeout: - details["connectTimeoutMS"] = connect_timeout * 1000 - return details - - -def format_timeout_details(details: Optional[dict[str, float]]) -> str: - result = "" - if details: - result += " (configured timeouts:" - for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]: - if timeout in details: - result += f" {timeout}: {details[timeout]}ms," - result = result[:-1] - result += ")" - return result - - -class _CancellationContext: - def __init__(self) -> None: - self._cancelled = False - - def cancel(self) -> None: - """Cancel this context.""" - self._cancelled = True - - @property - def cancelled(self) -> bool: - """Was cancel called?""" - return self._cancelled - class AsyncConnection: """Store a connection with some metadata. @@ -261,7 +133,11 @@ class AsyncConnection: """ def __init__( - self, conn: Union[socket.socket, _sslConn], pool: Pool, address: tuple[str, int], id: int + self, + conn: AsyncNetworkingInterface, + pool: Pool, + address: tuple[str, int], + id: int, ): self.pool_ref = weakref.ref(pool) self.conn = conn @@ -318,7 +194,7 @@ def set_conn_timeout(self, timeout: Optional[float]) -> None: if timeout == self.last_timeout: return self.last_timeout = timeout - self.conn.settimeout(timeout) + self.conn.get_conn.settimeout(timeout) def apply_timeout( self, client: AsyncMongoClient, cmd: Optional[MutableMapping[str, Any]] @@ -364,7 +240,7 @@ async def unpin(self) -> None: if pool: await pool.checkin(self) else: - self.close_conn(ConnectionClosedReason.STALE) + await self.close_conn(ConnectionClosedReason.STALE) def hello_cmd(self) -> dict[str, Any]: # Handshake spec requires us to use OP_MSG+hello command for the @@ -559,7 +435,7 @@ async def command( raise # Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves. except BaseException as error: - self._raise_connection_failure(error) + await self._raise_connection_failure(error) async def send_message(self, message: bytes, max_doc_size: int) -> None: """Send a raw BSON message or raise ConnectionFailure. @@ -573,10 +449,10 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None: ) try: - await async_sendall(self.conn, message) + await async_sendall(self.conn.get_conn, message) # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: - self._raise_connection_failure(error) + await self._raise_connection_failure(error) async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]: """Receive a raw BSON message or raise ConnectionFailure. @@ -584,10 +460,10 @@ async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _O If any exception is raised, the socket is closed. """ try: - return await receive_message(self, request_id, self.max_message_size) + return await async_receive_message(self, request_id, self.max_message_size) # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: - self._raise_connection_failure(error) + await self._raise_connection_failure(error) def _raise_if_not_writable(self, unacknowledged: bool) -> None: """Raise NotPrimaryError on unacknowledged write if this socket is not @@ -673,11 +549,11 @@ def validate_session( "Can only use session with the AsyncMongoClient that started it" ) - def close_conn(self, reason: Optional[str]) -> None: + async def close_conn(self, reason: Optional[str]) -> None: """Close this connection with a reason.""" if self.closed: return - self._close_conn() + await self._close_conn() if reason: if self.enabled_for_cmap: assert self.listeners is not None @@ -694,7 +570,7 @@ def close_conn(self, reason: Optional[str]) -> None: error=reason, ) - def _close_conn(self) -> None: + async def _close_conn(self) -> None: """Close this connection.""" if self.closed: return @@ -703,13 +579,16 @@ def _close_conn(self) -> None: # Note: We catch exceptions to avoid spurious errors on interpreter # shutdown. try: - self.conn.close() + await self.conn.close() except Exception: # noqa: S110 pass def conn_closed(self) -> bool: """Return True if we know socket has been closed, False otherwise.""" - return self.socket_checker.socket_closed(self.conn) + if _IS_SYNC: + return self.socket_checker.socket_closed(self.conn.get_conn) + else: + return self.conn.is_closing() def send_cluster_time( self, @@ -736,7 +615,7 @@ def idle_time_seconds(self) -> float: """Seconds since this socket was last checked into its pool.""" return time.monotonic() - self.last_checkin_time - def _raise_connection_failure(self, error: BaseException) -> NoReturn: + async def _raise_connection_failure(self, error: BaseException) -> NoReturn: # Catch *all* exceptions from socket methods and close the socket. In # regular Python, socket operations only raise socket.error, even if # the underlying cause was a Ctrl-C: a signal raised during socket.recv @@ -756,7 +635,7 @@ def _raise_connection_failure(self, error: BaseException) -> NoReturn: reason = None else: reason = ConnectionClosedReason.ERROR - self.close_conn(reason) + await self.close_conn(reason) # SSLError from PyOpenSSL inherits directly from Exception. if isinstance(error, (IOError, OSError, SSLError)): details = _get_timeout_details(self.opts) @@ -781,145 +660,6 @@ def __repr__(self) -> str: ) -async def _create_connection(address: _Address, options: PoolOptions) -> socket.socket: - """Given (host, port) and PoolOptions, connect and return a socket object. - - Can raise socket.error. - - This is a modified version of create_connection from CPython >= 2.7. - """ - host, port = address - - # Check if dealing with a unix domain socket - if host.endswith(".sock"): - if not hasattr(socket, "AF_UNIX"): - raise ConnectionFailure("UNIX-sockets are not supported on this system") - sock = socket.socket(socket.AF_UNIX) - # SOCK_CLOEXEC not supported for Unix sockets. - _set_non_inheritable_non_atomic(sock.fileno()) - try: - sock.connect(host) - return sock - except OSError: - sock.close() - raise - - # Don't try IPv6 if we don't support it. Also skip it if host - # is 'localhost' (::1 is fine). Avoids slow connect issues - # like PYTHON-356. - family = socket.AF_INET - if socket.has_ipv6 and host != "localhost": - family = socket.AF_UNSPEC - - err = None - for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined] - af, socktype, proto, dummy, sa = res - # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited - # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 - # all file descriptors are created non-inheritable. See PEP 446. - try: - sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto) - except OSError: - # Can SOCK_CLOEXEC be defined even if the kernel doesn't support - # it? - sock = socket.socket(af, socktype, proto) - # Fallback when SOCK_CLOEXEC isn't available. - _set_non_inheritable_non_atomic(sock.fileno()) - try: - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - # CSOT: apply timeout to socket connect. - timeout = _csot.remaining() - if timeout is None: - timeout = options.connect_timeout - elif timeout <= 0: - raise socket.timeout("timed out") - sock.settimeout(timeout) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True) - _set_keepalive_times(sock) - sock.connect(sa) - return sock - except OSError as e: - err = e - sock.close() - - if err is not None: - raise err - else: - # This likely means we tried to connect to an IPv6 only - # host with an OS/kernel or Python interpreter that doesn't - # support IPv6. The test case is Jython2.5.1 which doesn't - # support IPv6 at all. - raise OSError("getaddrinfo failed") - - -async def _configured_socket( - address: _Address, options: PoolOptions -) -> Union[socket.socket, _sslConn]: - """Given (host, port) and PoolOptions, return a configured socket. - - Can raise socket.error, ConnectionFailure, or _CertificateError. - - Sets socket's SSL and timeout options. - """ - sock = await _create_connection(address, options) - ssl_context = options._ssl_context - - if ssl_context is None: - sock.settimeout(options.socket_timeout) - return sock - - host = address[0] - try: - # We have to pass hostname / ip address to wrap_socket - # to use SSLContext.check_hostname. - if HAS_SNI: - if _IS_SYNC: - ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) - else: - if hasattr(ssl_context, "a_wrap_socket"): - ssl_sock = await ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc] - else: - loop = asyncio.get_running_loop() - ssl_sock = await loop.run_in_executor( - None, - functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc] - ) - else: - if _IS_SYNC: - ssl_sock = ssl_context.wrap_socket(sock) - else: - if hasattr(ssl_context, "a_wrap_socket"): - ssl_sock = await ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc] - else: - loop = asyncio.get_running_loop() - ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc] - except _CertificateError: - sock.close() - # Raise _CertificateError directly like we do after match_hostname - # below. - raise - except (OSError, SSLError) as exc: - sock.close() - # We raise AutoReconnect for transient and permanent SSL handshake - # failures alike. Permanent handshake failures, like protocol - # mismatch, will be turned into ServerSelectionTimeoutErrors later. - details = _get_timeout_details(options) - _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) - if ( - ssl_context.verify_mode - and not ssl_context.check_hostname - and not options.tls_allow_invalid_hostnames - ): - try: - ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined] - except _CertificateError: - ssl_sock.close() - raise - - ssl_sock.settimeout(options.socket_timeout) - return ssl_sock - - class _PoolClosedError(PyMongoError): """Internal error raised when a thread tries to get a connection from a closed pool. @@ -1121,7 +861,7 @@ async def _reset( # publishing the PoolClearedEvent. if close: for conn in sockets: - conn.close_conn(ConnectionClosedReason.POOL_CLOSED) + await conn.close_conn(ConnectionClosedReason.POOL_CLOSED) if self.enabled_for_cmap: assert listeners is not None listeners.publish_pool_closed(self.address) @@ -1152,7 +892,7 @@ async def _reset( serviceId=service_id, ) for conn in sockets: - conn.close_conn(ConnectionClosedReason.STALE) + await conn.close_conn(ConnectionClosedReason.STALE) async def update_is_writable(self, is_writable: Optional[bool]) -> None: """Updates the is_writable attribute on all sockets currently in the @@ -1197,7 +937,7 @@ async def remove_stale_sockets(self, reference_generation: int) -> None: and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds ): conn = self.conns.pop() - conn.close_conn(ConnectionClosedReason.IDLE) + await conn.close_conn(ConnectionClosedReason.IDLE) while True: async with self.size_cond: @@ -1221,7 +961,7 @@ async def remove_stale_sockets(self, reference_generation: int) -> None: # Close connection and return if the pool was reset during # socket creation or while acquiring the pool lock. if self.gen.get_overall() != reference_generation: - conn.close_conn(ConnectionClosedReason.STALE) + await conn.close_conn(ConnectionClosedReason.STALE) return self.conns.appendleft(conn) self.active_contexts.discard(conn.cancel_context) @@ -1266,7 +1006,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A ) try: - sock = await _configured_socket(self.address, self.opts) + networking_interface = await _configured_protocol_interface(self.address, self.opts) # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: async with self.lock: @@ -1293,7 +1033,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A raise - conn = AsyncConnection(sock, self, self.address, conn_id) # type: ignore[arg-type] + conn = AsyncConnection(networking_interface, self, self.address, conn_id) # type: ignore[arg-type] async with self.lock: self.active_contexts.add(conn.cancel_context) self.active_contexts.discard(tmp_context) @@ -1311,7 +1051,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A except BaseException: async with self.lock: self.active_contexts.discard(conn.cancel_context) - conn.close_conn(ConnectionClosedReason.ERROR) + await conn.close_conn(ConnectionClosedReason.ERROR) raise if handler: @@ -1509,7 +1249,7 @@ async def _get_conn( except IndexError: self._pending += 1 if conn: # We got a socket from the pool - if self._perished(conn): + if await self._perished(conn): conn = None continue else: # We need to create a new connection @@ -1523,7 +1263,7 @@ async def _get_conn( except BaseException: if conn: # We checked out a socket but authentication failed. - conn.close_conn(ConnectionClosedReason.ERROR) + await conn.close_conn(ConnectionClosedReason.ERROR) async with self.size_cond: self.requests -= 1 if incremented: @@ -1583,7 +1323,7 @@ async def checkin(self, conn: AsyncConnection) -> None: await self.reset_without_pause() else: if self.closed: - conn.close_conn(ConnectionClosedReason.POOL_CLOSED) + await conn.close_conn(ConnectionClosedReason.POOL_CLOSED) elif conn.closed: # CMAP requires the closed event be emitted after the check in. if self.enabled_for_cmap: @@ -1607,7 +1347,7 @@ async def checkin(self, conn: AsyncConnection) -> None: # Hold the lock to ensure this section does not race with # Pool.reset(). if self.stale_generation(conn.generation, conn.service_id): - conn.close_conn(ConnectionClosedReason.STALE) + await conn.close_conn(ConnectionClosedReason.STALE) else: conn.update_last_checkin_time() conn.update_is_writable(bool(self.is_writable)) @@ -1625,7 +1365,7 @@ async def checkin(self, conn: AsyncConnection) -> None: self.operation_count -= 1 self.size_cond.notify() - def _perished(self, conn: AsyncConnection) -> bool: + async def _perished(self, conn: AsyncConnection) -> bool: """Return True and close the connection if it is "perished". This side-effecty function checks if this socket has been idle for @@ -1645,18 +1385,18 @@ def _perished(self, conn: AsyncConnection) -> bool: self.opts.max_idle_time_seconds is not None and idle_time_seconds > self.opts.max_idle_time_seconds ): - conn.close_conn(ConnectionClosedReason.IDLE) + await conn.close_conn(ConnectionClosedReason.IDLE) return True if self._check_interval_seconds is not None and ( self._check_interval_seconds == 0 or idle_time_seconds > self._check_interval_seconds ): if conn.conn_closed(): - conn.close_conn(ConnectionClosedReason.ERROR) + await conn.close_conn(ConnectionClosedReason.ERROR) return True if self.stale_generation(conn.generation, conn.service_id): - conn.close_conn(ConnectionClosedReason.STALE) + await conn.close_conn(ConnectionClosedReason.STALE) return True return False @@ -1704,5 +1444,6 @@ def __del__(self) -> None: # Avoid ResourceWarnings in Python 3 # Close all sockets without calling reset() or close() because it is # not safe to acquire a lock in __del__. - for conn in self.conns: - conn.close_conn(None) + if _IS_SYNC: + for conn in self.conns: + conn.close_conn(None) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 4512aba59f..e287655c61 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -16,21 +16,26 @@ from __future__ import annotations import asyncio +import collections import errno import socket import struct import sys import time -from asyncio import AbstractEventLoop, Future +from asyncio import AbstractEventLoop, BaseTransport, BufferedProtocol, Future, Transport from typing import ( TYPE_CHECKING, + Any, Optional, Union, ) from pymongo import _csot, ssl_support from pymongo._asyncio_task import create_task -from pymongo.errors import _OperationCancelled +from pymongo.common import MAX_MESSAGE_SIZE +from pymongo.compression_support import decompress +from pymongo.errors import ProtocolError, _OperationCancelled +from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply from pymongo.socket_checker import _errno_from_exception try: @@ -69,13 +74,15 @@ BLOCKING_IO_ERRORS = (BlockingIOError, BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS) -async def async_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None: +# These socket-based I/O methods are for KMS requests and any other network operations that do not use +# the MongoDB wire protocol +async def async_socket_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None: timeout = sock.gettimeout() sock.settimeout(0.0) loop = asyncio.get_running_loop() try: if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): - await asyncio.wait_for(_async_sendall_ssl(sock, buf, loop), timeout=timeout) + await asyncio.wait_for(_async_socket_sendall_ssl(sock, buf, loop), timeout=timeout) else: await asyncio.wait_for(loop.sock_sendall(sock, buf), timeout=timeout) # type: ignore[arg-type] except asyncio.TimeoutError as exc: @@ -87,7 +94,7 @@ async def async_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> Non if sys.platform != "win32": - async def _async_sendall_ssl( + async def _async_socket_sendall_ssl( sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop ) -> None: view = memoryview(buf) @@ -130,7 +137,7 @@ def _is_ready(fut: Future) -> None: loop.remove_reader(fd) loop.remove_writer(fd) - async def _async_receive_ssl( + async def _async_socket_receive_ssl( conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False ) -> memoryview: mv = memoryview(bytearray(length)) @@ -184,7 +191,7 @@ def _is_ready(fut: Future) -> None: # The default Windows asyncio event loop does not support loop.add_reader/add_writer: # https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support # Note: In PYTHON-4493 we plan to replace this code with asyncio streams. - async def _async_sendall_ssl( + async def _async_socket_sendall_ssl( sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop ) -> None: view = memoryview(buf) @@ -205,7 +212,7 @@ async def _async_sendall_ssl( backoff = min(backoff * 2, 0.512) total_sent += sent - async def _async_receive_ssl( + async def _async_socket_receive_ssl( conn: _sslConn, length: int, dummy: AbstractEventLoop, once: Optional[bool] = False ) -> memoryview: mv = memoryview(bytearray(length)) @@ -244,52 +251,6 @@ async def _poll_cancellation(conn: AsyncConnection) -> None: await asyncio.sleep(_POLL_TIMEOUT) -async def async_receive_data( - conn: AsyncConnection, length: int, deadline: Optional[float] -) -> memoryview: - sock = conn.conn - sock_timeout = sock.gettimeout() - timeout: Optional[Union[float, int]] - if deadline: - # When the timeout has expired perform one final check to - # see if the socket is readable. This helps avoid spurious - # timeouts on AWS Lambda and other FaaS environments. - timeout = max(deadline - time.monotonic(), 0) - else: - timeout = sock_timeout - - sock.settimeout(0.0) - loop = asyncio.get_running_loop() - cancellation_task = create_task(_poll_cancellation(conn)) - try: - if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): - read_task = create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type] - else: - read_task = create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type] - tasks = [read_task, cancellation_task] - try: - done, pending = await asyncio.wait( - tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED - ) - for task in pending: - task.cancel() - if pending: - await asyncio.wait(pending) - if len(done) == 0: - raise socket.timeout("timed out") - if read_task in done: - return read_task.result() - raise _OperationCancelled("operation cancelled") - except asyncio.CancelledError: - for task in tasks: - task.cancel() - await asyncio.wait(tasks) - raise - - finally: - sock.settimeout(sock_timeout) - - async def async_receive_data_socket( sock: Union[socket.socket, _sslConn], length: int ) -> memoryview: @@ -301,18 +262,23 @@ async def async_receive_data_socket( try: if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): return await asyncio.wait_for( - _async_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type] + _async_socket_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type] timeout=timeout, ) else: - return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type] + return await asyncio.wait_for( + _async_socket_receive(sock, length, loop), # type: ignore[arg-type] + timeout=timeout, + ) except asyncio.TimeoutError as err: raise socket.timeout("timed out") from err finally: sock.settimeout(sock_timeout) -async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview: +async def _async_socket_receive( + conn: socket.socket, length: int, loop: AbstractEventLoop +) -> memoryview: mv = memoryview(bytearray(length)) bytes_read = 0 while bytes_read < length: @@ -328,7 +294,7 @@ async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLo def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: """Block until at least one byte is read, or a timeout, or a cancel.""" - sock = conn.conn + sock = conn.conn.sock timed_out = False # Check if the connection's socket has been manually closed if sock.fileno() == -1: @@ -413,3 +379,403 @@ def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> me conn.set_conn_timeout(orig_timeout) return mv + + +class NetworkingInterfaceBase: + def __init__(self, conn: Any): + self.conn = conn + + @property + def gettimeout(self) -> Any: + raise NotImplementedError + + def settimeout(self, timeout: float | None) -> None: + raise NotImplementedError + + def close(self) -> Any: + raise NotImplementedError + + def is_closing(self) -> bool: + raise NotImplementedError + + @property + def get_conn(self) -> Any: + raise NotImplementedError + + @property + def sock(self) -> Any: + raise NotImplementedError + + +class AsyncNetworkingInterface(NetworkingInterfaceBase): + def __init__(self, conn: tuple[Transport, PyMongoProtocol]): + super().__init__(conn) + + @property + def gettimeout(self) -> float | None: + return self.conn[1].gettimeout + + def settimeout(self, timeout: float | None) -> None: + self.conn[1].settimeout(timeout) + + async def close(self) -> None: + self.conn[1].close() + await self.conn[1].wait_closed() + + def is_closing(self) -> bool: + return self.conn[0].is_closing() + + @property + def get_conn(self) -> PyMongoProtocol: + return self.conn[1] + + @property + def sock(self) -> socket.socket: + return self.conn[0].get_extra_info("socket") + + +class NetworkingInterface(NetworkingInterfaceBase): + def __init__(self, conn: Union[socket.socket, _sslConn]): + super().__init__(conn) + + def gettimeout(self) -> float | None: + return self.conn.gettimeout() + + def settimeout(self, timeout: float | None) -> None: + self.conn.settimeout(timeout) + + def close(self) -> None: + self.conn.close() + + def is_closing(self) -> bool: + return self.conn.is_closing() + + @property + def get_conn(self) -> Union[socket.socket, _sslConn]: + return self.conn + + @property + def sock(self) -> Union[socket.socket, _sslConn]: + return self.conn + + def fileno(self) -> int: + return self.conn.fileno() + + def recv_into(self, buffer: bytes) -> int: + return self.conn.recv_into(buffer) + + +class PyMongoProtocol(BufferedProtocol): + def __init__(self, timeout: Optional[float] = None): + self.transport: Transport = None # type: ignore[assignment] + # Each message is reader in 2-3 parts: header, compression header, and message body + # The message buffer is allocated after the header is read. + self._header = memoryview(bytearray(16)) + self._header_index = 0 + self._compression_header = memoryview(bytearray(9)) + self._compression_index = 0 + self._message: Optional[memoryview] = None + self._message_index = 0 + # State. TODO: replace booleans with an enum? + self._expecting_header = True + self._expecting_compression = False + self._message_size = 0 + self._op_code = 0 + self._connection_lost = False + self._read_waiter: Optional[Future] = None + self._timeout = timeout + self._is_compressed = False + self._compressor_id: Optional[int] = None + self._max_message_size = MAX_MESSAGE_SIZE + self._response_to: Optional[int] = None + self._closed = asyncio.get_running_loop().create_future() + self._pending_messages: collections.deque[Future] = collections.deque() + self._done_messages: collections.deque[Future] = collections.deque() + + def settimeout(self, timeout: float | None) -> None: + self._timeout = timeout + + @property + def gettimeout(self) -> float | None: + """The configured timeout for the socket that underlies our protocol pair.""" + return self._timeout + + def connection_made(self, transport: BaseTransport) -> None: + """Called exactly once when a connection is made. + The transport argument is the transport representing the write side of the connection. + """ + self.transport = transport # type: ignore[assignment] + self.transport.set_write_buffer_limits(MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE) + + async def write(self, message: bytes) -> None: + """Write a message to this connection's transport.""" + if self.transport.is_closing(): + raise OSError("Connection is closed") + self.transport.write(message) + self.transport.resume_reading() + + async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[bytes, int]: + """Read a single MongoDB Wire Protocol message from this connection.""" + if self.transport: + try: + self.transport.resume_reading() + # Known bug in SSL Protocols, fixed in Python 3.11: https://github.com/python/cpython/issues/89322 + except AttributeError: + raise OSError("connection is already closed") from None + self._max_message_size = max_message_size + if self._done_messages: + message = await self._done_messages.popleft() + else: + if self.transport and self.transport.is_closing(): + raise OSError("connection is already closed") + read_waiter = asyncio.get_running_loop().create_future() + self._pending_messages.append(read_waiter) + try: + message = await read_waiter + finally: + if read_waiter in self._done_messages: + self._done_messages.remove(read_waiter) + if message: + op_code, compressor_id, response_to, data = message + # No request_id for exhaust cursor "getMore". + if request_id is not None: + if request_id != response_to: + raise ProtocolError( + f"Got response id {response_to!r} but expected {request_id!r}" + ) + if compressor_id is not None: + data = decompress(data, compressor_id) + return data, op_code + raise OSError("connection closed") + + def get_buffer(self, sizehint: int) -> memoryview: + """Called to allocate a new receive buffer. + The asyncio loop calls this method expecting to receive a non-empty buffer to fill with data. + If any data does not fit into the returned buffer, this method will be called again until + either no data remains or an empty buffer is returned. + """ + # Due to a bug, Python <=3.11 will call get_buffer() even after we raise + # ProtocolError in buffer_updated() and call connection_lost(). We allocate + # a temp buffer to drain the waiting data. + if self._connection_lost: + if not self._message: + self._message = memoryview(bytearray(2**14)) + return self._message + # TODO: optimize this by caching pointers to the buffers. + # return self._buffer[self._index:] + if self._expecting_header: + return self._header[self._header_index :] + if self._expecting_compression: + return self._compression_header[self._compression_index :] + return self._message[self._message_index :] # type: ignore[index] + + def buffer_updated(self, nbytes: int) -> None: + """Called when the buffer was updated with the received data""" + # Wrote 0 bytes into a non-empty buffer, signal connection closed + if nbytes == 0: + self.close(OSError("connection closed")) + return + if self._connection_lost: + return + if self._expecting_header: + self._header_index += nbytes + if self._header_index >= 16: + self._expecting_header = False + try: + ( + self._message_size, + self._op_code, + self._response_to, + self._expecting_compression, + ) = self.process_header() + except ProtocolError as exc: + self.close(exc) + return + self._message = memoryview(bytearray(self._message_size)) + return + if self._expecting_compression: + self._compression_index += nbytes + if self._compression_index >= 9: + self._expecting_compression = False + self._op_code, self._compressor_id = self.process_compression_header() + return + + self._message_index += nbytes + if self._message_index >= self._message_size: + self._expecting_header = True + # Pause reading to avoid storing an arbitrary number of messages in memory. + self.transport.pause_reading() + if self._pending_messages: + result = self._pending_messages.popleft() + else: + result = asyncio.get_running_loop().create_future() + # Future has been cancelled, close this connection + if result.done(): + self.close(None) + return + # Necessary values to reconstruct and verify message + result.set_result( + (self._op_code, self._compressor_id, self._response_to, self._message) + ) + self._done_messages.append(result) + # Reset internal state to expect a new message + self._header_index = 0 + self._compression_index = 0 + self._message_index = 0 + self._message_size = 0 + self._message = None + self._op_code = 0 + self._compressor_id = None + self._response_to = None + + def process_header(self) -> tuple[int, int, int, bool]: + """Unpack a MongoDB Wire Protocol header.""" + length, _, response_to, op_code = _UNPACK_HEADER(self._header) + expecting_compression = False + if op_code == 2012: # OP_COMPRESSED + if length <= 25: + raise ProtocolError( + f"Message length ({length!r}) not longer than standard OP_COMPRESSED message header size (25)" + ) + expecting_compression = True + length -= 9 + if length <= 16: + raise ProtocolError( + f"Message length ({length!r}) not longer than standard message header size (16)" + ) + if length > self._max_message_size: + raise ProtocolError( + f"Message length ({length!r}) is larger than server max " + f"message size ({self._max_message_size!r})" + ) + + return length - 16, op_code, response_to, expecting_compression + + def process_compression_header(self) -> tuple[int, int]: + """Unpack a MongoDB Wire Protocol compression header.""" + op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(self._compression_header) + return op_code, compressor_id + + def _resolve_pending_messages(self, exc: Optional[Exception] = None) -> None: + pending = list(self._pending_messages) + for msg in pending: + if not msg.done(): + if exc is None: + msg.set_result(None) + else: + msg.set_exception(exc) + self._done_messages.append(msg) + + def close(self, exc: Optional[Exception] = None) -> None: + self.transport.abort() + self._resolve_pending_messages(exc) + self._connection_lost = True + + def connection_lost(self, exc: Optional[Exception] = None) -> None: + self._resolve_pending_messages(exc) + if not self._closed.done(): + self._closed.set_result(None) + + async def wait_closed(self) -> None: + await self._closed + + +async def async_sendall(conn: PyMongoProtocol, buf: bytes) -> None: + try: + await asyncio.wait_for(conn.write(buf), timeout=conn.gettimeout) + except asyncio.TimeoutError as exc: + # Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands. + raise socket.timeout("timed out") from exc + + +async def async_receive_message( + conn: AsyncConnection, + request_id: Optional[int], + max_message_size: int = MAX_MESSAGE_SIZE, +) -> Union[_OpReply, _OpMsg]: + """Receive a raw BSON message or raise socket.error.""" + timeout: Optional[Union[float, int]] + timeout = conn.conn.gettimeout + if _csot.get_timeout(): + deadline = _csot.get_deadline() + else: + if timeout: + deadline = time.monotonic() + timeout + else: + deadline = None + if deadline: + # When the timeout has expired perform one final check to + # see if the socket is readable. This helps avoid spurious + # timeouts on AWS Lambda and other FaaS environments. + timeout = max(deadline - time.monotonic(), 0) + + cancellation_task = create_task(_poll_cancellation(conn)) + read_task = create_task(conn.conn.get_conn.read(request_id, max_message_size)) + tasks = [read_task, cancellation_task] + try: + done, pending = await asyncio.wait( + tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED + ) + for task in pending: + task.cancel() + if pending: + await asyncio.wait(pending) + if len(done) == 0: + raise socket.timeout("timed out") + if read_task in done: + data, op_code = read_task.result() + try: + unpack_reply = _UNPACK_REPLY[op_code] + except KeyError: + raise ProtocolError( + f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" + ) from None + return unpack_reply(data) + raise _OperationCancelled("operation cancelled") + except asyncio.CancelledError: + for task in tasks: + task.cancel() + await asyncio.wait(tasks) + raise + + +def receive_message( + conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE +) -> Union[_OpReply, _OpMsg]: + """Receive a raw BSON message or raise socket.error.""" + if _csot.get_timeout(): + deadline = _csot.get_deadline() + else: + timeout = conn.conn.gettimeout() + if timeout: + deadline = time.monotonic() + timeout + else: + deadline = None + # Ignore the response's request id. + length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline)) + # No request_id for exhaust cursor "getMore". + if request_id is not None: + if request_id != response_to: + raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}") + if length <= 16: + raise ProtocolError( + f"Message length ({length!r}) not longer than standard message header size (16)" + ) + if length > max_message_size: + raise ProtocolError( + f"Message length ({length!r}) is larger than server max " + f"message size ({max_message_size!r})" + ) + if op_code == 2012: + op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline)) + data = decompress(receive_data(conn, length - 25, deadline), compressor_id) + else: + data = receive_data(conn, length - 16, deadline) + + try: + unpack_reply = _UNPACK_REPLY[op_code] + except KeyError: + raise ProtocolError( + f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" + ) from None + return unpack_reply(data) diff --git a/pymongo/pool_shared.py b/pymongo/pool_shared.py new file mode 100644 index 0000000000..42b330b1e2 --- /dev/null +++ b/pymongo/pool_shared.py @@ -0,0 +1,546 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pool utilities and shared helper methods.""" +from __future__ import annotations + +import asyncio +import functools +import socket +import ssl +import sys +from typing import ( + TYPE_CHECKING, + Any, + NoReturn, + Optional, + Union, +) + +from pymongo import _csot +from pymongo.asynchronous.helpers import _getaddrinfo +from pymongo.errors import ( # type:ignore[attr-defined] + AutoReconnect, + ConnectionFailure, + NetworkTimeout, + _CertificateError, +) +from pymongo.network_layer import AsyncNetworkingInterface, NetworkingInterface, PyMongoProtocol +from pymongo.pool_options import PoolOptions +from pymongo.ssl_support import HAS_SNI, SSLError + +if TYPE_CHECKING: + from pymongo.pyopenssl_context import _sslConn + from pymongo.typings import _Address + +try: + from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl + + def _set_non_inheritable_non_atomic(fd: int) -> None: + """Set the close-on-exec flag on the given file descriptor.""" + flags = fcntl(fd, F_GETFD) + fcntl(fd, F_SETFD, flags | FD_CLOEXEC) + +except ImportError: + # Windows, various platforms we don't claim to support + # (Jython, IronPython, ..), systems that don't provide + # everything we need from fcntl, etc. + def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001 + """Dummy function for platforms that don't provide fcntl.""" + + +_MAX_TCP_KEEPIDLE = 120 +_MAX_TCP_KEEPINTVL = 10 +_MAX_TCP_KEEPCNT = 9 + +if sys.platform == "win32": + try: + import _winreg as winreg + except ImportError: + import winreg + + def _query(key, name, default): + try: + value, _ = winreg.QueryValueEx(key, name) + # Ensure the value is a number or raise ValueError. + return int(value) + except (OSError, ValueError): + # QueryValueEx raises OSError when the key does not exist (i.e. + # the system is using the Windows default value). + return default + + try: + with winreg.OpenKey( + winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters" + ) as key: + _WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000) + _WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000) + except OSError: + # We could not check the default values because winreg.OpenKey failed. + # Assume the system is using the default values. + _WINDOWS_TCP_IDLE_MS = 7200000 + _WINDOWS_TCP_INTERVAL_MS = 1000 + + def _set_keepalive_times(sock): + idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000) + interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000) + if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS: + sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms)) + +else: + + def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None: + if hasattr(socket, tcp_option): + sockopt = getattr(socket, tcp_option) + try: + # PYTHON-1350 - NetBSD doesn't implement getsockopt for + # TCP_KEEPIDLE and friends. Don't attempt to set the + # values there. + default = sock.getsockopt(socket.IPPROTO_TCP, sockopt) + if default > max_value: + sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value) + except OSError: + pass + + def _set_keepalive_times(sock: socket.socket) -> None: + _set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE) + _set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL) + _set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT) + + +def _raise_connection_failure( + address: Any, + error: Exception, + msg_prefix: Optional[str] = None, + timeout_details: Optional[dict[str, float]] = None, +) -> NoReturn: + """Convert a socket.error to ConnectionFailure and raise it.""" + host, port = address + # If connecting to a Unix socket, port will be None. + if port is not None: + msg = "%s:%d: %s" % (host, port, error) + else: + msg = f"{host}: {error}" + if msg_prefix: + msg = msg_prefix + msg + if "configured timeouts" not in msg: + msg += format_timeout_details(timeout_details) + if isinstance(error, socket.timeout): + raise NetworkTimeout(msg) from error + elif isinstance(error, SSLError) and "timed out" in str(error): + # Eventlet does not distinguish TLS network timeouts from other + # SSLErrors (https://github.com/eventlet/eventlet/issues/692). + # Luckily, we can work around this limitation because the phrase + # 'timed out' appears in all the timeout related SSLErrors raised. + raise NetworkTimeout(msg) from error + else: + raise AutoReconnect(msg) from error + + +def _get_timeout_details(options: PoolOptions) -> dict[str, float]: + details = {} + timeout = _csot.get_timeout() + socket_timeout = options.socket_timeout + connect_timeout = options.connect_timeout + if timeout: + details["timeoutMS"] = timeout * 1000 + if socket_timeout and not timeout: + details["socketTimeoutMS"] = socket_timeout * 1000 + if connect_timeout: + details["connectTimeoutMS"] = connect_timeout * 1000 + return details + + +def format_timeout_details(details: Optional[dict[str, float]]) -> str: + result = "" + if details: + result += " (configured timeouts:" + for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]: + if timeout in details: + result += f" {timeout}: {details[timeout]}ms," + result = result[:-1] + result += ")" + return result + + +class _CancellationContext: + def __init__(self) -> None: + self._cancelled = False + + def cancel(self) -> None: + """Cancel this context.""" + self._cancelled = True + + @property + def cancelled(self) -> bool: + """Was cancel called?""" + return self._cancelled + + +async def _async_create_connection(address: _Address, options: PoolOptions) -> socket.socket: + """Given (host, port) and PoolOptions, connect and return a raw socket object. + + Can raise socket.error. + + This is a modified version of create_connection from CPython >= 2.7. + """ + host, port = address + + # Check if dealing with a unix domain socket + if host.endswith(".sock"): + if not hasattr(socket, "AF_UNIX"): + raise ConnectionFailure("UNIX-sockets are not supported on this system") + sock = socket.socket(socket.AF_UNIX) + # SOCK_CLOEXEC not supported for Unix sockets. + _set_non_inheritable_non_atomic(sock.fileno()) + try: + sock.connect(host) + return sock + except OSError: + sock.close() + raise + + # Don't try IPv6 if we don't support it. Also skip it if host + # is 'localhost' (::1 is fine). Avoids slow connect issues + # like PYTHON-356. + family = socket.AF_INET + if socket.has_ipv6 and host != "localhost": + family = socket.AF_UNSPEC + + err = None + for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): + af, socktype, proto, dummy, sa = res + # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited + # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 + # all file descriptors are created non-inheritable. See PEP 446. + try: + sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto) + except OSError: + # Can SOCK_CLOEXEC be defined even if the kernel doesn't support + # it? + sock = socket.socket(af, socktype, proto) + # Fallback when SOCK_CLOEXEC isn't available. + _set_non_inheritable_non_atomic(sock.fileno()) + try: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + # CSOT: apply timeout to socket connect. + timeout = _csot.remaining() + if timeout is None: + timeout = options.connect_timeout + elif timeout <= 0: + raise socket.timeout("timed out") + sock.settimeout(timeout) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True) + _set_keepalive_times(sock) + sock.connect(sa) + return sock + except OSError as e: + err = e + sock.close() + + if err is not None: + raise err + else: + # This likely means we tried to connect to an IPv6 only + # host with an OS/kernel or Python interpreter that doesn't + # support IPv6. The test case is Jython2.5.1 which doesn't + # support IPv6 at all. + raise OSError("getaddrinfo failed") + + +async def _async_configured_socket( + address: _Address, options: PoolOptions +) -> Union[socket.socket, _sslConn]: + """Given (host, port) and PoolOptions, return a raw configured socket. + + Can raise socket.error, ConnectionFailure, or _CertificateError. + + Sets socket's SSL and timeout options. + """ + sock = await _async_create_connection(address, options) + ssl_context = options._ssl_context + + if ssl_context is None: + sock.settimeout(options.socket_timeout) + return sock + + host = address[0] + try: + # We have to pass hostname / ip address to wrap_socket + # to use SSLContext.check_hostname. + if HAS_SNI: + if hasattr(ssl_context, "a_wrap_socket"): + ssl_sock = await ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc, unused-ignore] + else: + loop = asyncio.get_running_loop() + ssl_sock = await loop.run_in_executor( + None, + functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc, unused-ignore] + ) + else: + if hasattr(ssl_context, "a_wrap_socket"): + ssl_sock = await ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc, unused-ignore] + else: + loop = asyncio.get_running_loop() + ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc, unused-ignore] + except _CertificateError: + sock.close() + # Raise _CertificateError directly like we do after match_hostname + # below. + raise + except (OSError, SSLError) as exc: + sock.close() + # We raise AutoReconnect for transient and permanent SSL handshake + # failures alike. Permanent handshake failures, like protocol + # mismatch, will be turned into ServerSelectionTimeoutErrors later. + details = _get_timeout_details(options) + _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): + try: + ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore] + except _CertificateError: + ssl_sock.close() + raise + + ssl_sock.settimeout(options.socket_timeout) + return ssl_sock + + +async def _configured_protocol_interface( + address: _Address, options: PoolOptions +) -> AsyncNetworkingInterface: + """Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface. + + Can raise socket.error, ConnectionFailure, or _CertificateError. + + Sets protocol's SSL and timeout options. + """ + sock = await _async_create_connection(address, options) + ssl_context = options._ssl_context + timeout = options.socket_timeout + + if ssl_context is None: + return AsyncNetworkingInterface( + await asyncio.get_running_loop().create_connection( + lambda: PyMongoProtocol(timeout=timeout), sock=sock + ) + ) + + host = address[0] + try: + # We have to pass hostname / ip address to wrap_socket + # to use SSLContext.check_hostname. + transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload] + lambda: PyMongoProtocol(timeout=timeout), + sock=sock, + server_hostname=host, + ssl=ssl_context, + ) + except _CertificateError: + transport.abort() + # Raise _CertificateError directly like we do after match_hostname + # below. + raise + except (OSError, SSLError) as exc: + transport.abort() + # We raise AutoReconnect for transient and permanent SSL handshake + # failures alike. Permanent handshake failures, like protocol + # mismatch, will be turned into ServerSelectionTimeoutErrors later. + details = _get_timeout_details(options) + _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): + try: + ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore] + except _CertificateError: + transport.abort() + raise + + return AsyncNetworkingInterface((transport, protocol)) + + +def _create_connection(address: _Address, options: PoolOptions) -> socket.socket: + """Given (host, port) and PoolOptions, connect and return a raw socket object. + + Can raise socket.error. + + This is a modified version of create_connection from CPython >= 2.7. + """ + host, port = address + + # Check if dealing with a unix domain socket + if host.endswith(".sock"): + if not hasattr(socket, "AF_UNIX"): + raise ConnectionFailure("UNIX-sockets are not supported on this system") + sock = socket.socket(socket.AF_UNIX) + # SOCK_CLOEXEC not supported for Unix sockets. + _set_non_inheritable_non_atomic(sock.fileno()) + try: + sock.connect(host) + return sock + except OSError: + sock.close() + raise + + # Don't try IPv6 if we don't support it. Also skip it if host + # is 'localhost' (::1 is fine). Avoids slow connect issues + # like PYTHON-356. + family = socket.AF_INET + if socket.has_ipv6 and host != "localhost": + family = socket.AF_UNSPEC + + err = None + for res in socket.getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined, unused-ignore] + af, socktype, proto, dummy, sa = res + # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited + # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 + # all file descriptors are created non-inheritable. See PEP 446. + try: + sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto) + except OSError: + # Can SOCK_CLOEXEC be defined even if the kernel doesn't support + # it? + sock = socket.socket(af, socktype, proto) + # Fallback when SOCK_CLOEXEC isn't available. + _set_non_inheritable_non_atomic(sock.fileno()) + try: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + # CSOT: apply timeout to socket connect. + timeout = _csot.remaining() + if timeout is None: + timeout = options.connect_timeout + elif timeout <= 0: + raise socket.timeout("timed out") + sock.settimeout(timeout) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True) + _set_keepalive_times(sock) + sock.connect(sa) + return sock + except OSError as e: + err = e + sock.close() + + if err is not None: + raise err + else: + # This likely means we tried to connect to an IPv6 only + # host with an OS/kernel or Python interpreter that doesn't + # support IPv6. The test case is Jython2.5.1 which doesn't + # support IPv6 at all. + raise OSError("getaddrinfo failed") + + +def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]: + """Given (host, port) and PoolOptions, return a raw configured socket. + + Can raise socket.error, ConnectionFailure, or _CertificateError. + + Sets socket's SSL and timeout options. + """ + sock = _create_connection(address, options) + ssl_context = options._ssl_context + + if ssl_context is None: + sock.settimeout(options.socket_timeout) + return sock + + host = address[0] + try: + # We have to pass hostname / ip address to wrap_socket + # to use SSLContext.check_hostname. + if HAS_SNI: + ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc, unused-ignore] + else: + ssl_sock = ssl_context.wrap_socket(sock) # type: ignore[assignment, misc, unused-ignore] + except _CertificateError: + sock.close() + # Raise _CertificateError directly like we do after match_hostname + # below. + raise + except (OSError, SSLError) as exc: + sock.close() + # We raise AutoReconnect for transient and permanent SSL handshake + # failures alike. Permanent handshake failures, like protocol + # mismatch, will be turned into ServerSelectionTimeoutErrors later. + details = _get_timeout_details(options) + _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): + try: + ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore] + except _CertificateError: + ssl_sock.close() + raise + + ssl_sock.settimeout(options.socket_timeout) + return ssl_sock + + +def _configured_socket_interface(address: _Address, options: PoolOptions) -> NetworkingInterface: + """Given (host, port) and PoolOptions, return a NetworkingInterface wrapping a configured socket. + + Can raise socket.error, ConnectionFailure, or _CertificateError. + + Sets socket's SSL and timeout options. + """ + sock = _create_connection(address, options) + ssl_context = options._ssl_context + + if ssl_context is None: + sock.settimeout(options.socket_timeout) + return NetworkingInterface(sock) + + host = address[0] + try: + # We have to pass hostname / ip address to wrap_socket + # to use SSLContext.check_hostname. + if HAS_SNI: + ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) + else: + ssl_sock = ssl_context.wrap_socket(sock) + except _CertificateError: + sock.close() + # Raise _CertificateError directly like we do after match_hostname + # below. + raise + except (OSError, SSLError) as exc: + sock.close() + # We raise AutoReconnect for transient and permanent SSL handshake + # failures alike. Permanent handshake failures, like protocol + # mismatch, will be turned into ServerSelectionTimeoutErrors later. + details = _get_timeout_details(options) + _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): + try: + ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined,unused-ignore] + except _CertificateError: + ssl_sock.close() + raise + + ssl_sock.settimeout(options.socket_timeout) + return NetworkingInterface(ssl_sock) diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 38c28de91e..ed631e135d 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -70,21 +70,21 @@ NetworkTimeout, ServerSelectionTimeoutError, ) -from pymongo.network_layer import BLOCKING_IO_ERRORS, sendall +from pymongo.network_layer import sendall from pymongo.operations import UpdateOne from pymongo.pool_options import PoolOptions +from pymongo.pool_shared import ( + _configured_socket, + _get_timeout_details, + _raise_connection_failure, +) from pymongo.read_concern import ReadConcern from pymongo.results import BulkWriteResult, DeleteResult -from pymongo.ssl_support import get_ssl_context +from pymongo.ssl_support import BLOCKING_IO_ERRORS, get_ssl_context from pymongo.synchronous.collection import Collection from pymongo.synchronous.cursor import Cursor from pymongo.synchronous.database import Database from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.pool import ( - _configured_socket, - _get_timeout_details, - _raise_connection_failure, -) from pymongo.typings import _DocumentType, _DocumentTypeArg from pymongo.uri_parser_shared import parse_host from pymongo.write_concern import WriteConcern diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index a2b76c4e8a..1413bb1437 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -36,7 +36,11 @@ from pymongo.synchronous.srv_resolver import _SrvResolver if TYPE_CHECKING: - from pymongo.synchronous.pool import Connection, Pool, _CancellationContext + from pymongo.synchronous.pool import ( # type: ignore[attr-defined] + Connection, + Pool, + _CancellationContext, + ) from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index 0e53e806b0..786edb7003 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -17,7 +17,6 @@ import datetime import logging -import time from typing import ( TYPE_CHECKING, Any, @@ -31,20 +30,16 @@ from bson import _decode_all_selective from pymongo import _csot, helpers_shared, message -from pymongo.common import MAX_MESSAGE_SIZE -from pymongo.compression_support import _NO_COMPRESSION, decompress +from pymongo.compression_support import _NO_COMPRESSION from pymongo.errors import ( NotPrimaryError, OperationFailure, - ProtocolError, ) from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log -from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply +from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import ( - _UNPACK_COMPRESSION_HEADER, - _UNPACK_HEADER, - receive_data, + receive_message, sendall, ) @@ -194,7 +189,7 @@ def command( ) try: - sendall(conn.conn, msg) + sendall(conn.conn.get_conn, msg) if use_op_msg and unacknowledged: # Unacknowledged, fake a successful command response. reply = None @@ -301,45 +296,3 @@ def command( ) return response_doc # type: ignore[return-value] - - -def receive_message( - conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE -) -> Union[_OpReply, _OpMsg]: - """Receive a raw BSON message or raise socket.error.""" - if _csot.get_timeout(): - deadline = _csot.get_deadline() - else: - timeout = conn.conn.gettimeout() - if timeout: - deadline = time.monotonic() + timeout - else: - deadline = None - # Ignore the response's request id. - length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline)) - # No request_id for exhaust cursor "getMore". - if request_id is not None: - if request_id != response_to: - raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}") - if length <= 16: - raise ProtocolError( - f"Message length ({length!r}) not longer than standard message header size (16)" - ) - if length > max_message_size: - raise ProtocolError( - f"Message length ({length!r}) is larger than server max " - f"message size ({max_message_size!r})" - ) - if op_code == 2012: - op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline)) - data = decompress(receive_data(conn, length - 25, deadline), compressor_id) - else: - data = receive_data(conn, length - 16, deadline) - - try: - unpack_reply = _UNPACK_REPLY[op_code] - except KeyError: - raise ProtocolError( - f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" - ) from None - return unpack_reply(data) diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index cd78e26fea..6a302e2728 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -14,14 +14,10 @@ from __future__ import annotations -import asyncio import collections import contextlib -import functools import logging import os -import socket -import ssl import sys import time import weakref @@ -49,16 +45,13 @@ from pymongo.errors import ( # type:ignore[attr-defined] AutoReconnect, ConfigurationError, - ConnectionFailure, DocumentTooLarge, ExecutionTimeout, InvalidOperation, - NetworkTimeout, NotPrimaryError, OperationFailure, PyMongoError, WaitQueueTimeoutError, - _CertificateError, ) from pymongo.hello import Hello, HelloCompat from pymongo.lock import ( @@ -76,16 +69,23 @@ ConnectionCheckOutFailedReason, ConnectionClosedReason, ) -from pymongo.network_layer import sendall +from pymongo.network_layer import NetworkingInterface, receive_message, sendall from pymongo.pool_options import PoolOptions +from pymongo.pool_shared import ( + _CancellationContext, + _configured_socket_interface, + _get_timeout_details, + _raise_connection_failure, + format_timeout_details, +) from pymongo.read_preferences import ReadPreference from pymongo.server_api import _add_to_command from pymongo.server_type import SERVER_TYPE from pymongo.socket_checker import SocketChecker -from pymongo.ssl_support import HAS_SNI, SSLError +from pymongo.ssl_support import SSLError from pymongo.synchronous.client_session import _validate_session_write_concern -from pymongo.synchronous.helpers import _getaddrinfo, _handle_reauth -from pymongo.synchronous.network import command, receive_message +from pymongo.synchronous.helpers import _handle_reauth +from pymongo.synchronous.network import command if TYPE_CHECKING: from bson import CodecOptions @@ -96,7 +96,6 @@ ZstdContext, ) from pymongo.message import _OpMsg, _OpReply - from pymongo.pyopenssl_context import _sslConn from pymongo.read_concern import ReadConcern from pymongo.read_preferences import _ServerMode from pymongo.synchronous.auth import _AuthContext @@ -123,133 +122,6 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001 _IS_SYNC = True -_MAX_TCP_KEEPIDLE = 120 -_MAX_TCP_KEEPINTVL = 10 -_MAX_TCP_KEEPCNT = 9 - -if sys.platform == "win32": - try: - import _winreg as winreg - except ImportError: - import winreg - - def _query(key, name, default): - try: - value, _ = winreg.QueryValueEx(key, name) - # Ensure the value is a number or raise ValueError. - return int(value) - except (OSError, ValueError): - # QueryValueEx raises OSError when the key does not exist (i.e. - # the system is using the Windows default value). - return default - - try: - with winreg.OpenKey( - winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters" - ) as key: - _WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000) - _WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000) - except OSError: - # We could not check the default values because winreg.OpenKey failed. - # Assume the system is using the default values. - _WINDOWS_TCP_IDLE_MS = 7200000 - _WINDOWS_TCP_INTERVAL_MS = 1000 - - def _set_keepalive_times(sock): - idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000) - interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000) - if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS: - sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms)) - -else: - - def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None: - if hasattr(socket, tcp_option): - sockopt = getattr(socket, tcp_option) - try: - # PYTHON-1350 - NetBSD doesn't implement getsockopt for - # TCP_KEEPIDLE and friends. Don't attempt to set the - # values there. - default = sock.getsockopt(socket.IPPROTO_TCP, sockopt) - if default > max_value: - sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value) - except OSError: - pass - - def _set_keepalive_times(sock: socket.socket) -> None: - _set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE) - _set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL) - _set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT) - - -def _raise_connection_failure( - address: Any, - error: Exception, - msg_prefix: Optional[str] = None, - timeout_details: Optional[dict[str, float]] = None, -) -> NoReturn: - """Convert a socket.error to ConnectionFailure and raise it.""" - host, port = address - # If connecting to a Unix socket, port will be None. - if port is not None: - msg = "%s:%d: %s" % (host, port, error) - else: - msg = f"{host}: {error}" - if msg_prefix: - msg = msg_prefix + msg - if "configured timeouts" not in msg: - msg += format_timeout_details(timeout_details) - if isinstance(error, socket.timeout): - raise NetworkTimeout(msg) from error - elif isinstance(error, SSLError) and "timed out" in str(error): - # Eventlet does not distinguish TLS network timeouts from other - # SSLErrors (https://github.com/eventlet/eventlet/issues/692). - # Luckily, we can work around this limitation because the phrase - # 'timed out' appears in all the timeout related SSLErrors raised. - raise NetworkTimeout(msg) from error - else: - raise AutoReconnect(msg) from error - - -def _get_timeout_details(options: PoolOptions) -> dict[str, float]: - details = {} - timeout = _csot.get_timeout() - socket_timeout = options.socket_timeout - connect_timeout = options.connect_timeout - if timeout: - details["timeoutMS"] = timeout * 1000 - if socket_timeout and not timeout: - details["socketTimeoutMS"] = socket_timeout * 1000 - if connect_timeout: - details["connectTimeoutMS"] = connect_timeout * 1000 - return details - - -def format_timeout_details(details: Optional[dict[str, float]]) -> str: - result = "" - if details: - result += " (configured timeouts:" - for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]: - if timeout in details: - result += f" {timeout}: {details[timeout]}ms," - result = result[:-1] - result += ")" - return result - - -class _CancellationContext: - def __init__(self) -> None: - self._cancelled = False - - def cancel(self) -> None: - """Cancel this context.""" - self._cancelled = True - - @property - def cancelled(self) -> bool: - """Was cancel called?""" - return self._cancelled - class Connection: """Store a connection with some metadata. @@ -261,7 +133,11 @@ class Connection: """ def __init__( - self, conn: Union[socket.socket, _sslConn], pool: Pool, address: tuple[str, int], id: int + self, + conn: NetworkingInterface, + pool: Pool, + address: tuple[str, int], + id: int, ): self.pool_ref = weakref.ref(pool) self.conn = conn @@ -318,7 +194,7 @@ def set_conn_timeout(self, timeout: Optional[float]) -> None: if timeout == self.last_timeout: return self.last_timeout = timeout - self.conn.settimeout(timeout) + self.conn.get_conn.settimeout(timeout) def apply_timeout( self, client: MongoClient, cmd: Optional[MutableMapping[str, Any]] @@ -573,7 +449,7 @@ def send_message(self, message: bytes, max_doc_size: int) -> None: ) try: - sendall(self.conn, message) + sendall(self.conn.get_conn, message) # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: self._raise_connection_failure(error) @@ -707,7 +583,10 @@ def _close_conn(self) -> None: def conn_closed(self) -> bool: """Return True if we know socket has been closed, False otherwise.""" - return self.socket_checker.socket_closed(self.conn) + if _IS_SYNC: + return self.socket_checker.socket_closed(self.conn.get_conn) + else: + return self.conn.is_closing() def send_cluster_time( self, @@ -779,143 +658,6 @@ def __repr__(self) -> str: ) -def _create_connection(address: _Address, options: PoolOptions) -> socket.socket: - """Given (host, port) and PoolOptions, connect and return a socket object. - - Can raise socket.error. - - This is a modified version of create_connection from CPython >= 2.7. - """ - host, port = address - - # Check if dealing with a unix domain socket - if host.endswith(".sock"): - if not hasattr(socket, "AF_UNIX"): - raise ConnectionFailure("UNIX-sockets are not supported on this system") - sock = socket.socket(socket.AF_UNIX) - # SOCK_CLOEXEC not supported for Unix sockets. - _set_non_inheritable_non_atomic(sock.fileno()) - try: - sock.connect(host) - return sock - except OSError: - sock.close() - raise - - # Don't try IPv6 if we don't support it. Also skip it if host - # is 'localhost' (::1 is fine). Avoids slow connect issues - # like PYTHON-356. - family = socket.AF_INET - if socket.has_ipv6 and host != "localhost": - family = socket.AF_UNSPEC - - err = None - for res in _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined] - af, socktype, proto, dummy, sa = res - # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited - # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 - # all file descriptors are created non-inheritable. See PEP 446. - try: - sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto) - except OSError: - # Can SOCK_CLOEXEC be defined even if the kernel doesn't support - # it? - sock = socket.socket(af, socktype, proto) - # Fallback when SOCK_CLOEXEC isn't available. - _set_non_inheritable_non_atomic(sock.fileno()) - try: - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - # CSOT: apply timeout to socket connect. - timeout = _csot.remaining() - if timeout is None: - timeout = options.connect_timeout - elif timeout <= 0: - raise socket.timeout("timed out") - sock.settimeout(timeout) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True) - _set_keepalive_times(sock) - sock.connect(sa) - return sock - except OSError as e: - err = e - sock.close() - - if err is not None: - raise err - else: - # This likely means we tried to connect to an IPv6 only - # host with an OS/kernel or Python interpreter that doesn't - # support IPv6. The test case is Jython2.5.1 which doesn't - # support IPv6 at all. - raise OSError("getaddrinfo failed") - - -def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]: - """Given (host, port) and PoolOptions, return a configured socket. - - Can raise socket.error, ConnectionFailure, or _CertificateError. - - Sets socket's SSL and timeout options. - """ - sock = _create_connection(address, options) - ssl_context = options._ssl_context - - if ssl_context is None: - sock.settimeout(options.socket_timeout) - return sock - - host = address[0] - try: - # We have to pass hostname / ip address to wrap_socket - # to use SSLContext.check_hostname. - if HAS_SNI: - if _IS_SYNC: - ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) - else: - if hasattr(ssl_context, "a_wrap_socket"): - ssl_sock = ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc] - else: - loop = asyncio.get_running_loop() - ssl_sock = loop.run_in_executor( - None, - functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc] - ) - else: - if _IS_SYNC: - ssl_sock = ssl_context.wrap_socket(sock) - else: - if hasattr(ssl_context, "a_wrap_socket"): - ssl_sock = ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc] - else: - loop = asyncio.get_running_loop() - ssl_sock = loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc] - except _CertificateError: - sock.close() - # Raise _CertificateError directly like we do after match_hostname - # below. - raise - except (OSError, SSLError) as exc: - sock.close() - # We raise AutoReconnect for transient and permanent SSL handshake - # failures alike. Permanent handshake failures, like protocol - # mismatch, will be turned into ServerSelectionTimeoutErrors later. - details = _get_timeout_details(options) - _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) - if ( - ssl_context.verify_mode - and not ssl_context.check_hostname - and not options.tls_allow_invalid_hostnames - ): - try: - ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined] - except _CertificateError: - ssl_sock.close() - raise - - ssl_sock.settimeout(options.socket_timeout) - return ssl_sock - - class _PoolClosedError(PyMongoError): """Internal error raised when a thread tries to get a connection from a closed pool. @@ -1260,7 +1002,7 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect ) try: - sock = _configured_socket(self.address, self.opts) + networking_interface = _configured_socket_interface(self.address, self.opts) # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: with self.lock: @@ -1287,7 +1029,7 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect raise - conn = Connection(sock, self, self.address, conn_id) # type: ignore[arg-type] + conn = Connection(networking_interface, self, self.address, conn_id) # type: ignore[arg-type] with self.lock: self.active_contexts.add(conn.cancel_context) self.active_contexts.discard(tmp_context) @@ -1698,5 +1440,6 @@ def __del__(self) -> None: # Avoid ResourceWarnings in Python 3 # Close all sockets without calling reset() or close() because it is # not safe to acquire a lock in __del__. - for conn in self.conns: - conn.close_conn(None) + if _IS_SYNC: + for conn in self.conns: + conn.close_conn(None) diff --git a/pyproject.toml b/pyproject.toml index 353f527879..611cac13aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,6 +116,7 @@ filterwarnings = [ "module:unclosed None: self.deprecation_filter = DeprecationFilter() async def asyncTearDown(self) -> None: + await super().asyncTearDown() self.deprecation_filter.stop() @@ -196,6 +197,7 @@ async def asyncTearDown(self): SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")]) ) self.knobs.disable() + await super().asyncTearDown() async def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() diff --git a/test/asynchronous/utils.py b/test/asynchronous/utils.py index 4b68595397..f653c575e9 100644 --- a/test/asynchronous/utils.py +++ b/test/asynchronous/utils.py @@ -159,6 +159,7 @@ def __init__(self): self.cancel_context = _CancellationContext() self.more_to_come = False self.id = random.randint(0, 100) + self.server_connection_id = random.randint(0, 100) def close_conn(self, reason): pass diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 3c3a1a67ae..9ba15e8d78 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -22,6 +22,8 @@ import warnings from test import PyMongoTestCase +import pytest + sys.path[0:0] = [""] from test import unittest @@ -30,6 +32,8 @@ from pymongo import MongoClient from pymongo.synchronous.auth_oidc import OIDCCallback +pytestmark = pytest.mark.auth + _IS_SYNC = True _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth") diff --git a/test/test_client.py b/test/test_client.py index cd4ceb3299..038ba2241b 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1233,7 +1233,6 @@ def test_socket_timeout(self): no_timeout = self.client timeout_sec = 1 timeout = self.rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) - self.addCleanup(timeout.close) no_timeout.pymongo_test.drop_collection("test") no_timeout.pymongo_test.test.insert_one({"x": 1}) @@ -1296,7 +1295,7 @@ def test_waitQueueTimeoutMS(self): def test_socketKeepAlive(self): pool = get_pool(self.client) with pool.checkout() as conn: - keepalive = conn.conn.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) + keepalive = conn.conn.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) self.assertTrue(keepalive) @no_type_check diff --git a/test/test_client_bulk_write.py b/test/test_client_bulk_write.py index b00b2c1b03..866b179c9e 100644 --- a/test/test_client_bulk_write.py +++ b/test/test_client_bulk_write.py @@ -647,7 +647,6 @@ def test_timeout_in_multi_batch_bulk_write(self): _OVERHEAD = 500 internal_client = self.rs_or_single_client(timeoutMS=None) - self.addCleanup(internal_client.close) collection = internal_client.db["coll"] self.addCleanup(collection.drop) diff --git a/test/test_cursor.py b/test/test_cursor.py index a9cbe99942..7b75f4ddc4 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1801,6 +1801,7 @@ def test_monitoring(self): @client_context.require_version_min(5, 0, -1) @client_context.require_no_mongos + @client_context.require_sync def test_exhaust_cursor_db_set(self): listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) @@ -1810,7 +1811,7 @@ def test_exhaust_cursor_db_set(self): listener.reset() - result = c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1).to_list() + result = list(c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1)) self.assertEqual(len(result), 3) diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index b099820a45..598fc3fd76 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -137,6 +137,7 @@ def setUp(self) -> None: self.deprecation_filter = DeprecationFilter() def tearDown(self) -> None: + super().tearDown() self.deprecation_filter.stop() @@ -194,6 +195,7 @@ def tearDown(self): SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")]) ) self.knobs.disable() + super().tearDown() def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() diff --git a/test/utils.py b/test/utils.py index 1459a8fba7..3027ed7517 100644 --- a/test/utils.py +++ b/test/utils.py @@ -157,6 +157,7 @@ def __init__(self): self.cancel_context = _CancellationContext() self.more_to_come = False self.id = random.randint(0, 100) + self.server_connection_id = random.randint(0, 100) def close_conn(self, reason): pass diff --git a/tools/synchro.py b/tools/synchro.py index d8760b83bc..f451d09a26 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -47,6 +47,7 @@ "async_receive_message": "receive_message", "async_receive_data": "receive_data", "async_sendall": "sendall", + "async_socket_sendall": "sendall", "asynchronous": "synchronous", "Asynchronous": "Synchronous", "AsyncBulkTestBase": "BulkTestBase", @@ -119,6 +120,9 @@ "_async_create_lock": "_create_lock", "_async_create_condition": "_create_condition", "_async_cond_wait": "_cond_wait", + "AsyncNetworkingInterface": "NetworkingInterface", + "_configured_protocol_interface": "_configured_socket_interface", + "_async_configured_socket": "_configured_socket", "SpecRunnerTask": "SpecRunnerThread", "AsyncMockConnection": "MockConnection", "AsyncMockPool": "MockPool", @@ -127,6 +131,7 @@ "async_create_barrier": "create_barrier", "async_barrier_wait": "barrier_wait", "async_joinall": "joinall", + "_async_create_connection": "_create_connection", "pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts": "pymongo.synchronous.srv_resolver._SrvResolver.get_hosts", }