diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 149cb3ac85..613061e188 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -64,6 +64,7 @@ from pymongo.asynchronous.cursor import AsyncCursor from pymongo.asynchronous.database import AsyncDatabase from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.asynchronous.pool import AsyncBaseConnection from pymongo.common import CONNECT_TIMEOUT from pymongo.daemon import _spawn_daemon from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts @@ -75,11 +76,11 @@ NetworkTimeout, ServerSelectionTimeoutError, ) -from pymongo.network_layer import async_socket_sendall +from pymongo.network_layer import PyMongoKMSProtocol, async_receive_kms, async_sendall from pymongo.operations import UpdateOne from pymongo.pool_options import PoolOptions from pymongo.pool_shared import ( - _async_configured_socket, + _configured_protocol_interface, _get_timeout_details, _raise_connection_failure, ) @@ -93,10 +94,8 @@ if TYPE_CHECKING: from pymongocrypt.mongocrypt import MongoCryptKmsContext - from pymongo.pyopenssl_context import _sslConn from pymongo.typings import _Address - _IS_SYNC = False _HTTPS_PORT = 443 @@ -111,9 +110,10 @@ _KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument) -async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]: +async def _connect_kms(address: _Address, opts: PoolOptions) -> AsyncBaseConnection: try: - return await _async_configured_socket(address, opts) + interface = await _configured_protocol_interface(address, opts, PyMongoKMSProtocol) + return AsyncBaseConnection(interface, opts) except Exception as exc: _raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts)) @@ -198,18 +198,11 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: try: conn = await _connect_kms(address, opts) try: - await async_socket_sendall(conn, message) + await async_sendall(conn.conn.get_conn, message) while kms_context.bytes_needed > 0: # CSOT: update timeout. - conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - if _IS_SYNC: - data = conn.recv(kms_context.bytes_needed) - else: - from pymongo.network_layer import ( # type: ignore[attr-defined] - async_receive_data_socket, - ) - - data = await async_receive_data_socket(conn, kms_context.bytes_needed) + conn.set_conn_timeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) + data = await async_receive_kms(conn, kms_context.bytes_needed) if not data: raise OSError("KMS connection closed") kms_context.feed(data) @@ -228,7 +221,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: address, exc, msg_prefix=msg_prefix, timeout_details=_get_timeout_details(opts) ) finally: - conn.close() + await conn.close_conn(None) except MongoCryptError: raise # Propagate MongoCryptError errors directly. 
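# Illustrative sketch (not part of the patch): the feed loop in kms_request above, reduced
# to its shape. "kms_ctx" stands in for pymongocrypt's MongoCryptKmsContext and "read_up_to"
# for PyMongoKMSProtocol.read(); both names are assumptions for the example, not real APIs.
from typing import Any, Callable

def drive_kms_context(kms_ctx: Any, read_up_to: Callable[[int], bytes]) -> None:
    while kms_ctx.bytes_needed > 0:
        data = read_up_to(kms_ctx.bytes_needed)  # may legally return fewer bytes
        if not data:  # an empty read means the KMS host closed the connection
            raise OSError("KMS connection closed")
        kms_ctx.feed(data)  # mongocrypt updates bytes_needed after each feed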
except Exception as exc: diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index e215cafdc1..a57bd98451 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -124,7 +124,89 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001 _IS_SYNC = False -class AsyncConnection: +class AsyncBaseConnection: + """A base connection object for server and kms connections.""" + + def __init__(self, conn: AsyncNetworkingInterface, opts: PoolOptions): + self.conn = conn + self.socket_checker: SocketChecker = SocketChecker() + self.cancel_context: _CancellationContext = _CancellationContext() + self.is_sdam = False + self.closed = False + self.last_timeout: float | None = None + self.more_to_come = False + self.opts = opts + self.max_wire_version = -1 + + def set_conn_timeout(self, timeout: Optional[float]) -> None: + """Cache last timeout to avoid duplicate calls to conn.settimeout.""" + if timeout == self.last_timeout: + return + self.last_timeout = timeout + self.conn.get_conn.settimeout(timeout) + + def apply_timeout( + self, client: AsyncMongoClient[Any], cmd: Optional[MutableMapping[str, Any]] + ) -> Optional[float]: + # CSOT: use remaining timeout when set. + timeout = _csot.remaining() + if timeout is None: + # Reset the socket timeout unless we're performing a streaming monitor check. + if not self.more_to_come: + self.set_conn_timeout(self.opts.socket_timeout) + return None + # RTT validation. + rtt = _csot.get_rtt() + if rtt is None: + rtt = self.connect_rtt + max_time_ms = timeout - rtt + if max_time_ms < 0: + timeout_details = _get_timeout_details(self.opts) + formatted = format_timeout_details(timeout_details) + # CSOT: raise an error without running the command since we know it will time out. + errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}" + if self.max_wire_version != -1: + raise ExecutionTimeout( + errmsg, + 50, + {"ok": 0, "errmsg": errmsg, "code": 50}, + self.max_wire_version, + ) + else: + raise TimeoutError(errmsg) + if cmd is not None: + cmd["maxTimeMS"] = int(max_time_ms * 1000) + self.set_conn_timeout(timeout) + return timeout + + async def close_conn(self, reason: Optional[str]) -> None: + """Close this connection with a reason.""" + if self.closed: + return + await self._close_conn() + + async def _close_conn(self) -> None: + """Close this connection.""" + if self.closed: + return + self.closed = True + self.cancel_context.cancel() + # Note: We catch exceptions to avoid spurious errors on interpreter + # shutdown. + try: + await self.conn.close() + except Exception: # noqa: S110 + pass + + def conn_closed(self) -> bool: + """Return True if we know socket has been closed, False otherwise.""" + if _IS_SYNC: + return self.socket_checker.socket_closed(self.conn.get_conn) + else: + return self.conn.is_closing() + + +class AsyncConnection(AsyncBaseConnection): """Store a connection with some metadata. 
:param conn: a raw connection object @@ -142,29 +224,27 @@ def __init__( id: int, is_sdam: bool, ): + super().__init__(conn, pool.opts) self.pool_ref = weakref.ref(pool) - self.conn = conn - self.address = address - self.id = id + self.address: tuple[str, int] = address + self.id: int = id self.is_sdam = is_sdam - self.closed = False self.last_checkin_time = time.monotonic() self.performed_handshake = False self.is_writable: bool = False self.max_wire_version = MAX_WIRE_VERSION - self.max_bson_size = MAX_BSON_SIZE - self.max_message_size = MAX_MESSAGE_SIZE - self.max_write_batch_size = MAX_WRITE_BATCH_SIZE + self.max_bson_size: int = MAX_BSON_SIZE + self.max_message_size: int = MAX_MESSAGE_SIZE + self.max_write_batch_size: int = MAX_WRITE_BATCH_SIZE self.supports_sessions = False self.hello_ok: bool = False - self.is_mongos = False + self.is_mongos: bool = False self.op_msg_enabled = False self.listeners = pool.opts._event_listeners self.enabled_for_cmap = pool.enabled_for_cmap self.enabled_for_logging = pool.enabled_for_logging self.compression_settings = pool.opts._compression_settings self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None - self.socket_checker: SocketChecker = SocketChecker() self.oidc_token_gen_id: Optional[int] = None # Support for mechanism negotiation on the initial handshake. self.negotiated_mechs: Optional[list[str]] = None @@ -175,9 +255,6 @@ def __init__( self.pool_gen = pool.gen self.generation = self.pool_gen.get_overall() self.ready = False - self.cancel_context: _CancellationContext = _CancellationContext() - self.opts = pool.opts - self.more_to_come: bool = False # For load balancer support. self.service_id: Optional[ObjectId] = None self.server_connection_id: Optional[int] = None @@ -193,44 +270,6 @@ def __init__( # For gossiping $clusterTime from the connection handshake to the client. self._cluster_time = None - def set_conn_timeout(self, timeout: Optional[float]) -> None: - """Cache last timeout to avoid duplicate calls to conn.settimeout.""" - if timeout == self.last_timeout: - return - self.last_timeout = timeout - self.conn.get_conn.settimeout(timeout) - - def apply_timeout( - self, client: AsyncMongoClient[Any], cmd: Optional[MutableMapping[str, Any]] - ) -> Optional[float]: - # CSOT: use remaining timeout when set. - timeout = _csot.remaining() - if timeout is None: - # Reset the socket timeout unless we're performing a streaming monitor check. - if not self.more_to_come: - self.set_conn_timeout(self.opts.socket_timeout) - return None - # RTT validation. - rtt = _csot.get_rtt() - if rtt is None: - rtt = self.connect_rtt - max_time_ms = timeout - rtt - if max_time_ms < 0: - timeout_details = _get_timeout_details(self.opts) - formatted = format_timeout_details(timeout_details) - # CSOT: raise an error without running the command since we know it will time out. 
- errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}" - raise ExecutionTimeout( - errmsg, - 50, - {"ok": 0, "errmsg": errmsg, "code": 50}, - self.max_wire_version, - ) - if cmd is not None: - cmd["maxTimeMS"] = int(max_time_ms * 1000) - self.set_conn_timeout(timeout) - return timeout - def pin_txn(self) -> None: self.pinned_txn = True assert not self.pinned_cursor @@ -574,26 +613,6 @@ async def close_conn(self, reason: Optional[str]) -> None: error=reason, ) - async def _close_conn(self) -> None: - """Close this connection.""" - if self.closed: - return - self.closed = True - self.cancel_context.cancel() - # Note: We catch exceptions to avoid spurious errors on interpreter - # shutdown. - try: - await self.conn.close() - except Exception: # noqa: S110 - pass - - def conn_closed(self) -> bool: - """Return True if we know socket has been closed, False otherwise.""" - if _IS_SYNC: - return self.socket_checker.socket_closed(self.conn.get_conn) - else: - return self.conn.is_closing() - def send_cluster_time( self, command: MutableMapping[str, Any], diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 2f7f9c320f..6e4185adf7 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -22,10 +22,12 @@ import struct import sys import time -from asyncio import AbstractEventLoop, BaseTransport, BufferedProtocol, Future, Transport +from asyncio import BaseTransport, BufferedProtocol, Future, Transport +from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, + Callable, Optional, Union, ) @@ -38,208 +40,30 @@ from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply from pymongo.socket_checker import _errno_from_exception -try: - from ssl import SSLError, SSLSocket - - _HAVE_SSL = True -except ImportError: - _HAVE_SSL = False - -try: - from pymongo.pyopenssl_context import _sslConn - - _HAVE_PYOPENSSL = True -except ImportError: - _HAVE_PYOPENSSL = False - _sslConn = SSLSocket # type: ignore[assignment, misc] - -from pymongo.ssl_support import ( - BLOCKING_IO_LOOKUP_ERROR, - BLOCKING_IO_READ_ERROR, - BLOCKING_IO_WRITE_ERROR, -) - if TYPE_CHECKING: - from pymongo.asynchronous.pool import AsyncConnection - from pymongo.synchronous.pool import Connection + from pymongo.asynchronous.pool import AsyncBaseConnection, AsyncConnection + from pymongo.pyopenssl_context import _sslConn + from pymongo.synchronous.pool import BaseConnection, Connection _UNPACK_HEADER = struct.Struct(" None: - timeout = sock.gettimeout() - sock.settimeout(0.0) - loop = asyncio.get_running_loop() - try: - if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): - await asyncio.wait_for(_async_socket_sendall_ssl(sock, buf, loop), timeout=timeout) - else: - await asyncio.wait_for(loop.sock_sendall(sock, buf), timeout=timeout) # type: ignore[arg-type] - except asyncio.TimeoutError as exc: - # Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands. - raise socket.timeout("timed out") from exc - finally: - sock.settimeout(timeout) - - -if sys.platform != "win32": - - async def _async_socket_sendall_ssl( - sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop - ) -> None: - view = memoryview(buf) - sent = 0 - - def _is_ready(fut: Future[Any]) -> None: - if fut.done(): - return - fut.set_result(None) - - while sent < len(buf): - try: - sent += sock.send(view[sent:]) - except BLOCKING_IO_ERRORS as exc: - fd = sock.fileno() - # Check for closed socket. 
- if fd == -1: - raise SSLError("Underlying socket has been closed") from None - if isinstance(exc, BLOCKING_IO_READ_ERROR): - fut = loop.create_future() - loop.add_reader(fd, _is_ready, fut) - try: - await fut - finally: - loop.remove_reader(fd) - if isinstance(exc, BLOCKING_IO_WRITE_ERROR): - fut = loop.create_future() - loop.add_writer(fd, _is_ready, fut) - try: - await fut - finally: - loop.remove_writer(fd) - if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR): - fut = loop.create_future() - loop.add_reader(fd, _is_ready, fut) - try: - loop.add_writer(fd, _is_ready, fut) - await fut - finally: - loop.remove_reader(fd) - loop.remove_writer(fd) - - async def _async_socket_receive_ssl( - conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False - ) -> memoryview: - mv = memoryview(bytearray(length)) - total_read = 0 - - def _is_ready(fut: Future[Any]) -> None: - if fut.done(): - return - fut.set_result(None) +_PYPY = "PyPy" in sys.version +_WINDOWS = sys.platform == "win32" - while total_read < length: - try: - read = conn.recv_into(mv[total_read:]) - if read == 0: - raise OSError("connection closed") - # KMS responses update their expected size after the first batch, stop reading after one loop - if once: - return mv[:read] - total_read += read - except BLOCKING_IO_ERRORS as exc: - fd = conn.fileno() - # Check for closed socket. - if fd == -1: - raise SSLError("Underlying socket has been closed") from None - if isinstance(exc, BLOCKING_IO_READ_ERROR): - fut = loop.create_future() - loop.add_reader(fd, _is_ready, fut) - try: - await fut - finally: - loop.remove_reader(fd) - if isinstance(exc, BLOCKING_IO_WRITE_ERROR): - fut = loop.create_future() - loop.add_writer(fd, _is_ready, fut) - try: - await fut - finally: - loop.remove_writer(fd) - if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR): - fut = loop.create_future() - loop.add_reader(fd, _is_ready, fut) - try: - loop.add_writer(fd, _is_ready, fut) - await fut - finally: - loop.remove_reader(fd) - loop.remove_writer(fd) - return mv - -else: - # The default Windows asyncio event loop does not support loop.add_reader/add_writer: - # https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support - # Note: In PYTHON-4493 we plan to replace this code with asyncio streams. - async def _async_socket_sendall_ssl( - sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop - ) -> None: - view = memoryview(buf) - total_length = len(buf) - total_sent = 0 - # Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success - # down to 1ms. - backoff = 0.001 - while total_sent < total_length: - try: - sent = sock.send(view[total_sent:]) - except BLOCKING_IO_ERRORS: - await asyncio.sleep(backoff) - sent = 0 - if sent > 0: - backoff = max(backoff / 2, 0.001) - else: - backoff = min(backoff * 2, 0.512) - total_sent += sent - - async def _async_socket_receive_ssl( - conn: _sslConn, length: int, dummy: AbstractEventLoop, once: Optional[bool] = False - ) -> memoryview: - mv = memoryview(bytearray(length)) - total_read = 0 - # Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success - # down to 1ms. 
- backoff = 0.001 - while total_read < length: - try: - read = conn.recv_into(mv[total_read:]) - if read == 0: - raise OSError("connection closed") - # KMS responses update their expected size after the first batch, stop reading after one loop - if once: - return mv[:read] - except BLOCKING_IO_ERRORS: - await asyncio.sleep(backoff) - read = 0 - if read > 0: - backoff = max(backoff / 2, 0.001) - else: - backoff = min(backoff * 2, 0.512) - total_read += read - return mv +# Errors raised by sockets (and TLS sockets) when in non-blocking mode. +BLOCKING_IO_ERRORS = ( + BlockingIOError, + *ssl_support.BLOCKING_IO_LOOKUP_ERROR, + *ssl_support.BLOCKING_IO_ERRORS, +) def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None: sock.sendall(buf) -async def _poll_cancellation(conn: AsyncConnection) -> None: +async def _poll_cancellation(conn: AsyncBaseConnection) -> None: while True: if conn.cancel_context.cancelled: return @@ -247,49 +71,7 @@ async def _poll_cancellation(conn: AsyncConnection) -> None: await asyncio.sleep(_POLL_TIMEOUT) -async def async_receive_data_socket( - sock: Union[socket.socket, _sslConn], length: int -) -> memoryview: - sock_timeout = sock.gettimeout() - timeout = sock_timeout - - sock.settimeout(0.0) - loop = asyncio.get_running_loop() - try: - if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): - return await asyncio.wait_for( - _async_socket_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type] - timeout=timeout, - ) - else: - return await asyncio.wait_for( - _async_socket_receive(sock, length, loop), # type: ignore[arg-type] - timeout=timeout, - ) - except asyncio.TimeoutError as err: - raise socket.timeout("timed out") from err - finally: - sock.settimeout(sock_timeout) - - -async def _async_socket_receive( - conn: socket.socket, length: int, loop: AbstractEventLoop -) -> memoryview: - mv = memoryview(bytearray(length)) - bytes_read = 0 - while bytes_read < length: - chunk_length = await loop.sock_recv_into(conn, mv[bytes_read:]) - if chunk_length == 0: - raise OSError("connection closed") - bytes_read += chunk_length - return mv - - -_PYPY = "PyPy" in sys.version -_WINDOWS = sys.platform == "win32" - - -def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: +def wait_for_read(conn: BaseConnection, deadline: Optional[float]) -> None: """Block until at least one byte is read, or a timeout, or a cancel.""" sock = conn.conn.sock timed_out = False @@ -322,7 +104,7 @@ def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: raise socket.timeout("timed out") -def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview: +def receive_data(conn: BaseConnection, length: int, deadline: Optional[float]) -> memoryview: buf = bytearray(length) mv = memoryview(buf) bytes_read = 0 @@ -412,7 +194,7 @@ def sock(self) -> Any: class AsyncNetworkingInterface(NetworkingInterfaceBase): - def __init__(self, conn: tuple[Transport, PyMongoProtocol]): + def __init__(self, conn: tuple[Transport, PyMongoBaseProtocol]): super().__init__(conn) @property @@ -430,7 +212,7 @@ def is_closing(self) -> bool: return self.conn[0].is_closing() @property - def get_conn(self) -> PyMongoProtocol: + def get_conn(self) -> PyMongoBaseProtocol: return self.conn[1] @property @@ -469,9 +251,51 @@ def recv_into(self, buffer: bytes) -> int: return self.conn.recv_into(buffer) -class PyMongoProtocol(BufferedProtocol): +class PyMongoBaseProtocol(BufferedProtocol): def __init__(self, timeout: Optional[float] = None): self.transport: 
Transport = None # type: ignore[assignment] + self._timeout = timeout + self._closed = asyncio.get_running_loop().create_future() + self._connection_lost = False + + def settimeout(self, timeout: float | None) -> None: + self._timeout = timeout + + @property + def gettimeout(self) -> float | None: + """The configured timeout for the socket that underlies our protocol pair.""" + return self._timeout + + def close(self, exc: Optional[Exception] = None) -> None: + self.transport.abort() + self._resolve_pending(exc) + self._connection_lost = True + + def connection_lost(self, exc: Optional[Exception] = None) -> None: + self._resolve_pending(exc) + if not self._closed.done(): + self._closed.set_result(None) + + def _resolve_pending(self, exc: Optional[Exception] = None) -> None: + pass + + async def wait_closed(self) -> None: + await self._closed + + async def write(self, message: bytes) -> None: + """Write a message to this connection's transport.""" + if self.transport.is_closing(): + raise OSError("Connection is closed") + self.transport.write(message) + self.transport.resume_reading() + + async def read(self, *args: Any) -> Any: + raise NotImplementedError + + +class PyMongoProtocol(PyMongoBaseProtocol): + def __init__(self, timeout: Optional[float] = None): + super().__init__(timeout) # Each message is reader in 2-3 parts: header, compression header, and message body # The message buffer is allocated after the header is read. self._header = memoryview(bytearray(16)) @@ -485,25 +309,14 @@ def __init__(self, timeout: Optional[float] = None): self._expecting_compression = False self._message_size = 0 self._op_code = 0 - self._connection_lost = False self._read_waiter: Optional[Future[Any]] = None - self._timeout = timeout self._is_compressed = False self._compressor_id: Optional[int] = None self._max_message_size = MAX_MESSAGE_SIZE self._response_to: Optional[int] = None - self._closed = asyncio.get_running_loop().create_future() self._pending_messages: collections.deque[Future[Any]] = collections.deque() self._done_messages: collections.deque[Future[Any]] = collections.deque() - def settimeout(self, timeout: float | None) -> None: - self._timeout = timeout - - @property - def gettimeout(self) -> float | None: - """The configured timeout for the socket that underlies our protocol pair.""" - return self._timeout - def connection_made(self, transport: BaseTransport) -> None: """Called exactly once when a connection is made. The transport argument is the transport representing the write side of the connection. 
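# A minimal sketch of the asyncio BufferedProtocol contract that PyMongoBaseProtocol and its
# two subclasses build on (nothing pymongo-specific is assumed here): the event loop asks for
# a writable buffer via get_buffer(), fills it, then reports the byte count via
# buffer_updated(), which is what lets both protocols receive data without extra copies.
import asyncio


class _CollectorProtocol(asyncio.BufferedProtocol):
    def __init__(self) -> None:
        self._buf = memoryview(bytearray(64 * 1024))
        self.received = bytearray()

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.transport = transport

    def get_buffer(self, sizehint: int) -> memoryview:
        # Hand the loop a writable view; it writes incoming bytes straight into it.
        return self._buf

    def buffer_updated(self, nbytes: int) -> None:
        # The first nbytes of the buffer now hold fresh data.
        self.received += self._buf[:nbytes]


# Wiring mirrors the create_connection() calls in pool_shared.py:
#   transport, protocol = await loop.create_connection(_CollectorProtocol, sock=sock)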
@@ -511,13 +324,6 @@ def connection_made(self, transport: BaseTransport) -> None: self.transport = transport # type: ignore[assignment] self.transport.set_write_buffer_limits(MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE) - async def write(self, message: bytes) -> None: - """Write a message to this connection's transport.""" - if self.transport.is_closing(): - raise OSError("Connection is closed") - self.transport.write(message) - self.transport.resume_reading() - async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[bytes, int]: """Read a single MongoDB Wire Protocol message from this connection.""" if self.transport: @@ -660,7 +466,7 @@ def process_compression_header(self) -> tuple[int, int]: op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(self._compression_header) return op_code, compressor_id - def _resolve_pending_messages(self, exc: Optional[Exception] = None) -> None: + def _resolve_pending(self, exc: Optional[Exception] = None) -> None: pending = list(self._pending_messages) for msg in pending: if not msg.done(): @@ -670,21 +476,130 @@ def _resolve_pending_messages(self, exc: Optional[Exception] = None) -> None: msg.set_exception(exc) self._done_messages.append(msg) - def close(self, exc: Optional[Exception] = None) -> None: - self.transport.abort() - self._resolve_pending_messages(exc) - self._connection_lost = True - def connection_lost(self, exc: Optional[Exception] = None) -> None: - self._resolve_pending_messages(exc) - if not self._closed.done(): - self._closed.set_result(None) +@dataclass +class KMSBuffer: + buffer: memoryview + start_index: int + end_index: int - async def wait_closed(self) -> None: - await self._closed +class PyMongoKMSProtocol(PyMongoBaseProtocol): + def __init__(self, timeout: Optional[float] = None): + super().__init__(timeout) + self._buffers: collections.deque[KMSBuffer] = collections.deque() + self._bytes_ready = 0 + self._pending_reads: collections.deque[int] = collections.deque() + self._pending_listeners: collections.deque[Future[Any]] = collections.deque() + + def connection_made(self, transport: BaseTransport) -> None: + """Called exactly once when a connection is made. + The transport argument is the transport representing the write side of the connection. + """ + self.transport = transport # type: ignore[assignment] + + async def read(self, bytes_needed: int) -> bytes: + """Read up to the requested bytes from this connection.""" + # Note: all reads are "up-to" bytes_needed because we don't know if the kms_context + # has processed a Content-Length header and is requesting a response or not. + # Wait for other listeners first. + if len(self._pending_listeners): + await asyncio.gather(*self._pending_listeners) + # If there are bytes ready, then there is no need to wait further. + if self._bytes_ready > 0: + return self._read(bytes_needed) + if self.transport: + try: + self.transport.resume_reading() + # Known bug in SSL Protocols, fixed in Python 3.11: https://github.com/python/cpython/issues/89322 + except AttributeError: + raise OSError("connection is already closed") from None + if self.transport and self.transport.is_closing(): + raise OSError("connection is already closed") + self._pending_reads.append(bytes_needed) + read_waiter = asyncio.get_running_loop().create_future() + self._pending_listeners.append(read_waiter) + return await read_waiter + + def get_buffer(self, sizehint: int) -> memoryview: + """Called to allocate a new receive buffer. 
+ The asyncio loop calls this method expecting to receive a non-empty buffer to fill with data. + If any data does not fit into the returned buffer, this method will be called again until + either no data remains or an empty buffer is returned. + """ + # Reuse the active buffer if it has space. + # Allocate a bit more than the max response size for an AWS KMS response. + sizehint = max(sizehint, 16384) + if len(self._buffers): + buffer = self._buffers[-1] + if len(buffer.buffer) - buffer.end_index > sizehint: + return buffer.buffer[buffer.end_index :] + buffer = KMSBuffer(memoryview(bytearray(sizehint)), 0, 0) + self._buffers.append(buffer) + return buffer.buffer + + def _resolve_pending(self, exc: Optional[Exception] = None) -> None: + while self._pending_listeners: + fut = self._pending_listeners.popleft() + fut.set_result(b"") + + def buffer_updated(self, nbytes: int) -> None: + """Called when the buffer was updated with the received data""" + # Wrote 0 bytes into a non-empty buffer, signal connection closed + if nbytes == 0: + self.close(OSError("connection closed")) + return + if self._connection_lost: + return + self._bytes_ready += nbytes + + # Update the length of the current buffer. + self._buffers[-1].end_index += nbytes + + if not len(self._pending_reads): + return -async def async_sendall(conn: PyMongoProtocol, buf: bytes) -> None: + bytes_needed = self._pending_reads.popleft() + data = self._read(bytes_needed) + waiter = self._pending_listeners.popleft() + waiter.set_result(data) + + def _read(self, bytes_needed: int) -> memoryview: + """Read bytes from the buffer.""" + # Send the bytes to the listener. + if self._bytes_ready < bytes_needed: + bytes_needed = self._bytes_ready + self._bytes_ready -= bytes_needed + + output_buf = bytearray(bytes_needed) + n_remaining = bytes_needed + out_index = 0 + while n_remaining > 0: + buffer = self._buffers.popleft() + buffer_remaining = buffer.end_index - buffer.start_index + # if we didn't exhaust the buffer, read the partial data and return the buffer. + if buffer_remaining > n_remaining: + output_buf[out_index : n_remaining + out_index] = buffer.buffer[ + buffer.start_index : buffer.start_index + n_remaining + ] + buffer.start_index += n_remaining + n_remaining = 0 + self._buffers.appendleft(buffer) + # otherwise exhaust the buffer. + else: + output_buf[out_index : out_index + buffer_remaining] = buffer.buffer[ + buffer.start_index : buffer.end_index + ] + out_index += buffer_remaining + n_remaining -= buffer_remaining + # if this is the only buffer, add it back to the queue. 
+ if not len(self._buffers): + buffer.start_index = buffer.end_index + self._buffers.appendleft(buffer) + return memoryview(output_buf) + + +async def async_sendall(conn: PyMongoBaseProtocol, buf: bytes) -> None: try: await asyncio.wait_for(conn.write(buf), timeout=conn.gettimeout) except asyncio.TimeoutError as exc: @@ -692,12 +607,18 @@ async def async_sendall(conn: PyMongoProtocol, buf: bytes) -> None: raise socket.timeout("timed out") from exc -async def async_receive_message( - conn: AsyncConnection, - request_id: Optional[int], - max_message_size: int = MAX_MESSAGE_SIZE, -) -> Union[_OpReply, _OpMsg]: - """Receive a raw BSON message or raise socket.error.""" +async def async_receive_kms(conn: AsyncBaseConnection, bytes_needed: int) -> bytes: + """Receive raw bytes from the kms connection.""" + + def callback(result: Any) -> bytes: + return result + + return await _async_receive_data(conn, callback, bytes_needed) + + +async def _async_receive_data( + conn: AsyncBaseConnection, callback: Callable[..., Any], *args: Any +) -> Any: timeout: Optional[Union[float, int]] timeout = conn.conn.gettimeout if _csot.get_timeout(): @@ -713,8 +634,8 @@ async def async_receive_message( # timeouts on AWS Lambda and other FaaS environments. timeout = max(deadline - time.monotonic(), 0) + read_task = create_task(conn.conn.get_conn.read(*args)) cancellation_task = create_task(_poll_cancellation(conn)) - read_task = create_task(conn.conn.get_conn.read(request_id, max_message_size)) tasks = [read_task, cancellation_task] try: done, pending = await asyncio.wait( @@ -727,14 +648,7 @@ async def async_receive_message( if len(done) == 0: raise socket.timeout("timed out") if read_task in done: - data, op_code = read_task.result() - try: - unpack_reply = _UNPACK_REPLY[op_code] - except KeyError: - raise ProtocolError( - f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" - ) from None - return unpack_reply(data) + return callback(read_task.result()) raise _OperationCancelled("operation cancelled") except asyncio.CancelledError: for task in tasks: @@ -743,6 +657,31 @@ async def async_receive_message( raise +async def async_receive_message( + conn: AsyncConnection, + request_id: Optional[int], + max_message_size: int = MAX_MESSAGE_SIZE, +) -> Union[_OpReply, _OpMsg]: + """Receive a raw BSON message or raise socket.error.""" + + def callback(result: Any) -> _OpMsg | _OpReply: + data, op_code = result + try: + unpack_reply = _UNPACK_REPLY[op_code] + except KeyError: + raise ProtocolError( + f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" + ) from None + return unpack_reply(data) + + return await _async_receive_data(conn, callback, request_id, max_message_size) + + +def receive_kms(conn: BaseConnection, bytes_needed: int) -> bytes: + """Receive raw bytes from the kms connection.""" + return conn.conn.sock.recv(bytes_needed) + + def receive_message( conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE ) -> Union[_OpReply, _OpMsg]: diff --git a/pymongo/pool_shared.py b/pymongo/pool_shared.py index 905f1a4d18..f28226b791 100644 --- a/pymongo/pool_shared.py +++ b/pymongo/pool_shared.py @@ -16,7 +16,6 @@ from __future__ import annotations import asyncio -import functools import socket import ssl import sys @@ -25,7 +24,6 @@ Any, NoReturn, Optional, - Union, ) from pymongo import _csot @@ -36,13 +34,17 @@ NetworkTimeout, _CertificateError, ) -from pymongo.network_layer import AsyncNetworkingInterface, NetworkingInterface, PyMongoProtocol +from 
pymongo.network_layer import ( + AsyncNetworkingInterface, + NetworkingInterface, + PyMongoBaseProtocol, + PyMongoProtocol, +) from pymongo.pool_options import PoolOptions from pymongo.ssl_support import PYSSLError, SSLError, _has_sni SSLErrors = (PYSSLError, SSLError) if TYPE_CHECKING: - from pymongo.pyopenssl_context import _sslConn from pymongo.typings import _Address try: @@ -269,64 +271,10 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s raise OSError("getaddrinfo failed") -async def _async_configured_socket( - address: _Address, options: PoolOptions -) -> Union[socket.socket, _sslConn]: - """Given (host, port) and PoolOptions, return a raw configured socket. - - Can raise socket.error, ConnectionFailure, or _CertificateError. - - Sets socket's SSL and timeout options. - """ - sock = await _async_create_connection(address, options) - ssl_context = options._ssl_context - - if ssl_context is None: - sock.settimeout(options.socket_timeout) - return sock - - host = address[0] - try: - # We have to pass hostname / ip address to wrap_socket - # to use SSLContext.check_hostname. - if _has_sni(False): - loop = asyncio.get_running_loop() - ssl_sock = await loop.run_in_executor( - None, - functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc, unused-ignore] - ) - else: - loop = asyncio.get_running_loop() - ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc, unused-ignore] - except _CertificateError: - sock.close() - # Raise _CertificateError directly like we do after match_hostname - # below. - raise - except (OSError, *SSLErrors) as exc: - sock.close() - # We raise AutoReconnect for transient and permanent SSL handshake - # failures alike. Permanent handshake failures, like protocol - # mismatch, will be turned into ServerSelectionTimeoutErrors later. - details = _get_timeout_details(options) - _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) - if ( - ssl_context.verify_mode - and not ssl_context.check_hostname - and not options.tls_allow_invalid_hostnames - ): - try: - ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore] - except _CertificateError: - ssl_sock.close() - raise - - ssl_sock.settimeout(options.socket_timeout) - return ssl_sock - - async def _configured_protocol_interface( - address: _Address, options: PoolOptions + address: _Address, + options: PoolOptions, + protocol_kls: type[PyMongoBaseProtocol] = PyMongoProtocol, ) -> AsyncNetworkingInterface: """Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface. @@ -341,7 +289,7 @@ async def _configured_protocol_interface( if ssl_context is None: return AsyncNetworkingInterface( await asyncio.get_running_loop().create_connection( - lambda: PyMongoProtocol(timeout=timeout), sock=sock + lambda: protocol_kls(timeout=timeout), sock=sock ) ) @@ -350,7 +298,7 @@ async def _configured_protocol_interface( # We have to pass hostname / ip address to wrap_socket # to use SSLContext.check_hostname. 
transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload] - lambda: PyMongoProtocol(timeout=timeout), + lambda: protocol_kls(timeout=timeout), sock=sock, server_hostname=host, ssl=ssl_context, @@ -450,56 +398,9 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket raise OSError("getaddrinfo failed") -def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]: - """Given (host, port) and PoolOptions, return a raw configured socket. - - Can raise socket.error, ConnectionFailure, or _CertificateError. - - Sets socket's SSL and timeout options. - """ - sock = _create_connection(address, options) - ssl_context = options._ssl_context - - if ssl_context is None: - sock.settimeout(options.socket_timeout) - return sock - - host = address[0] - try: - # We have to pass hostname / ip address to wrap_socket - # to use SSLContext.check_hostname. - if _has_sni(True): - ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc, unused-ignore] - else: - ssl_sock = ssl_context.wrap_socket(sock) # type: ignore[assignment, misc, unused-ignore] - except _CertificateError: - sock.close() - # Raise _CertificateError directly like we do after match_hostname - # below. - raise - except (OSError, *SSLErrors) as exc: - sock.close() - # We raise AutoReconnect for transient and permanent SSL handshake - # failures alike. Permanent handshake failures, like protocol - # mismatch, will be turned into ServerSelectionTimeoutErrors later. - details = _get_timeout_details(options) - _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) - if ( - ssl_context.verify_mode - and not ssl_context.check_hostname - and not options.tls_allow_invalid_hostnames - ): - try: - ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore] - except _CertificateError: - ssl_sock.close() - raise - - ssl_sock.settimeout(options.socket_timeout) - return ssl_sock - - -def _configured_socket_interface(address: _Address, options: PoolOptions) -> NetworkingInterface: +def _configured_socket_interface( + address: _Address, options: PoolOptions, *args: Any +) -> NetworkingInterface: """Given (host, port) and PoolOptions, return a NetworkingInterface wrapping a configured socket. Can raise socket.error, ConnectionFailure, or _CertificateError. 
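# Sketch of the protocol-injection pattern added to _configured_protocol_interface above: the
# same connect/TLS setup can now yield either the wire-protocol class or the KMS protocol.
# The names below (open_interface) are hypothetical; only asyncio's real create_connection()
# signature is assumed.
import asyncio
import socket
from typing import Callable


async def open_interface(
    sock: socket.socket,
    protocol_kls: Callable[[], asyncio.BufferedProtocol],
) -> tuple[asyncio.Transport, asyncio.BufferedProtocol]:
    loop = asyncio.get_running_loop()
    # Callers choose the protocol: the server pool passes the message protocol,
    # the KMS path passes PyMongoKMSProtocol, and everything else stays identical.
    return await loop.create_connection(protocol_kls, sock=sock)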
diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index ba304e7bd3..7c8d95bba6 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -70,11 +70,11 @@ NetworkTimeout, ServerSelectionTimeoutError, ) -from pymongo.network_layer import sendall +from pymongo.network_layer import PyMongoKMSProtocol, receive_kms, sendall from pymongo.operations import UpdateOne from pymongo.pool_options import PoolOptions from pymongo.pool_shared import ( - _configured_socket, + _configured_socket_interface, _get_timeout_details, _raise_connection_failure, ) @@ -85,6 +85,7 @@ from pymongo.synchronous.cursor import Cursor from pymongo.synchronous.database import Database from pymongo.synchronous.mongo_client import MongoClient +from pymongo.synchronous.pool import BaseConnection from pymongo.typings import _DocumentType, _DocumentTypeArg from pymongo.uri_parser_shared import _parse_kms_tls_options, parse_host from pymongo.write_concern import WriteConcern @@ -92,10 +93,8 @@ if TYPE_CHECKING: from pymongocrypt.mongocrypt import MongoCryptKmsContext - from pymongo.pyopenssl_context import _sslConn from pymongo.typings import _Address - _IS_SYNC = True _HTTPS_PORT = 443 @@ -110,9 +109,10 @@ _KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument) -def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]: +def _connect_kms(address: _Address, opts: PoolOptions) -> BaseConnection: try: - return _configured_socket(address, opts) + interface = _configured_socket_interface(address, opts, PyMongoKMSProtocol) + return BaseConnection(interface, opts) except Exception as exc: _raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts)) @@ -197,18 +197,11 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: try: conn = _connect_kms(address, opts) try: - sendall(conn, message) + sendall(conn.conn.get_conn, message) while kms_context.bytes_needed > 0: # CSOT: update timeout. - conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - if _IS_SYNC: - data = conn.recv(kms_context.bytes_needed) - else: - from pymongo.network_layer import ( # type: ignore[attr-defined] - receive_data_socket, - ) - - data = receive_data_socket(conn, kms_context.bytes_needed) + conn.set_conn_timeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) + data = receive_kms(conn, kms_context.bytes_needed) if not data: raise OSError("KMS connection closed") kms_context.feed(data) @@ -227,7 +220,7 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: address, exc, msg_prefix=msg_prefix, timeout_details=_get_timeout_details(opts) ) finally: - conn.close() + conn.close_conn(None) except MongoCryptError: raise # Propagate MongoCryptError errors directly. 
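# Sketch of the per-read timeout clamp used in both kms_request loops: each read gets the
# smaller of the KMS connect timeout and whatever remains of the client-side deadline,
# floored at zero. clamp_to_deadline is a hypothetical stand-in for _csot.clamp_remaining.
import time
from typing import Optional


def clamp_to_deadline(deadline: Optional[float], cap: float) -> float:
    if deadline is None:  # no client-side deadline in effect
        return cap
    return max(min(deadline - time.monotonic(), cap), 0)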
except Exception as exc: diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 4ea5cb1c1e..434c70d288 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -124,7 +124,89 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001 _IS_SYNC = True -class Connection: +class BaseConnection: + """A base connection object for server and kms connections.""" + + def __init__(self, conn: NetworkingInterface, opts: PoolOptions): + self.conn = conn + self.socket_checker: SocketChecker = SocketChecker() + self.cancel_context: _CancellationContext = _CancellationContext() + self.is_sdam = False + self.closed = False + self.last_timeout: float | None = None + self.more_to_come = False + self.opts = opts + self.max_wire_version = -1 + + def set_conn_timeout(self, timeout: Optional[float]) -> None: + """Cache last timeout to avoid duplicate calls to conn.settimeout.""" + if timeout == self.last_timeout: + return + self.last_timeout = timeout + self.conn.get_conn.settimeout(timeout) + + def apply_timeout( + self, client: MongoClient[Any], cmd: Optional[MutableMapping[str, Any]] + ) -> Optional[float]: + # CSOT: use remaining timeout when set. + timeout = _csot.remaining() + if timeout is None: + # Reset the socket timeout unless we're performing a streaming monitor check. + if not self.more_to_come: + self.set_conn_timeout(self.opts.socket_timeout) + return None + # RTT validation. + rtt = _csot.get_rtt() + if rtt is None: + rtt = self.connect_rtt + max_time_ms = timeout - rtt + if max_time_ms < 0: + timeout_details = _get_timeout_details(self.opts) + formatted = format_timeout_details(timeout_details) + # CSOT: raise an error without running the command since we know it will time out. + errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}" + if self.max_wire_version != -1: + raise ExecutionTimeout( + errmsg, + 50, + {"ok": 0, "errmsg": errmsg, "code": 50}, + self.max_wire_version, + ) + else: + raise TimeoutError(errmsg) + if cmd is not None: + cmd["maxTimeMS"] = int(max_time_ms * 1000) + self.set_conn_timeout(timeout) + return timeout + + def close_conn(self, reason: Optional[str]) -> None: + """Close this connection with a reason.""" + if self.closed: + return + self._close_conn() + + def _close_conn(self) -> None: + """Close this connection.""" + if self.closed: + return + self.closed = True + self.cancel_context.cancel() + # Note: We catch exceptions to avoid spurious errors on interpreter + # shutdown. + try: + self.conn.close() + except Exception: # noqa: S110 + pass + + def conn_closed(self) -> bool: + """Return True if we know socket has been closed, False otherwise.""" + if _IS_SYNC: + return self.socket_checker.socket_closed(self.conn.get_conn) + else: + return self.conn.is_closing() + + +class Connection(BaseConnection): """Store a connection with some metadata. 
:param conn: a raw connection object @@ -142,29 +224,27 @@ def __init__( id: int, is_sdam: bool, ): + super().__init__(conn, pool.opts) self.pool_ref = weakref.ref(pool) - self.conn = conn - self.address = address - self.id = id + self.address: tuple[str, int] = address + self.id: int = id self.is_sdam = is_sdam - self.closed = False self.last_checkin_time = time.monotonic() self.performed_handshake = False self.is_writable: bool = False self.max_wire_version = MAX_WIRE_VERSION - self.max_bson_size = MAX_BSON_SIZE - self.max_message_size = MAX_MESSAGE_SIZE - self.max_write_batch_size = MAX_WRITE_BATCH_SIZE + self.max_bson_size: int = MAX_BSON_SIZE + self.max_message_size: int = MAX_MESSAGE_SIZE + self.max_write_batch_size: int = MAX_WRITE_BATCH_SIZE self.supports_sessions = False self.hello_ok: bool = False - self.is_mongos = False + self.is_mongos: bool = False self.op_msg_enabled = False self.listeners = pool.opts._event_listeners self.enabled_for_cmap = pool.enabled_for_cmap self.enabled_for_logging = pool.enabled_for_logging self.compression_settings = pool.opts._compression_settings self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None - self.socket_checker: SocketChecker = SocketChecker() self.oidc_token_gen_id: Optional[int] = None # Support for mechanism negotiation on the initial handshake. self.negotiated_mechs: Optional[list[str]] = None @@ -175,9 +255,6 @@ def __init__( self.pool_gen = pool.gen self.generation = self.pool_gen.get_overall() self.ready = False - self.cancel_context: _CancellationContext = _CancellationContext() - self.opts = pool.opts - self.more_to_come: bool = False # For load balancer support. self.service_id: Optional[ObjectId] = None self.server_connection_id: Optional[int] = None @@ -193,44 +270,6 @@ def __init__( # For gossiping $clusterTime from the connection handshake to the client. self._cluster_time = None - def set_conn_timeout(self, timeout: Optional[float]) -> None: - """Cache last timeout to avoid duplicate calls to conn.settimeout.""" - if timeout == self.last_timeout: - return - self.last_timeout = timeout - self.conn.get_conn.settimeout(timeout) - - def apply_timeout( - self, client: MongoClient[Any], cmd: Optional[MutableMapping[str, Any]] - ) -> Optional[float]: - # CSOT: use remaining timeout when set. - timeout = _csot.remaining() - if timeout is None: - # Reset the socket timeout unless we're performing a streaming monitor check. - if not self.more_to_come: - self.set_conn_timeout(self.opts.socket_timeout) - return None - # RTT validation. - rtt = _csot.get_rtt() - if rtt is None: - rtt = self.connect_rtt - max_time_ms = timeout - rtt - if max_time_ms < 0: - timeout_details = _get_timeout_details(self.opts) - formatted = format_timeout_details(timeout_details) - # CSOT: raise an error without running the command since we know it will time out. 
- errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}" - raise ExecutionTimeout( - errmsg, - 50, - {"ok": 0, "errmsg": errmsg, "code": 50}, - self.max_wire_version, - ) - if cmd is not None: - cmd["maxTimeMS"] = int(max_time_ms * 1000) - self.set_conn_timeout(timeout) - return timeout - def pin_txn(self) -> None: self.pinned_txn = True assert not self.pinned_cursor @@ -572,26 +611,6 @@ def close_conn(self, reason: Optional[str]) -> None: error=reason, ) - def _close_conn(self) -> None: - """Close this connection.""" - if self.closed: - return - self.closed = True - self.cancel_context.cancel() - # Note: We catch exceptions to avoid spurious errors on interpreter - # shutdown. - try: - self.conn.close() - except Exception: # noqa: S110 - pass - - def conn_closed(self) -> bool: - """Return True if we know socket has been closed, False otherwise.""" - if _IS_SYNC: - return self.socket_checker.socket_closed(self.conn.get_conn) - else: - return self.conn.is_closing() - def send_cluster_time( self, command: MutableMapping[str, Any], diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index cda8452d1c..6a85b63960 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -335,6 +335,8 @@ async def test_create_index(self): await db.test.create_index(["hello", ("world", DESCENDING)]) await db.test.create_index({"hello": 1}.items()) # type:ignore[arg-type] + # TODO: PYTHON-5491 - remove version max + @async_client_context.require_version_max(8, 0, -1) async def test_drop_index(self): db = self.db await db.test.drop_indexes() diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 964d2df96d..09bf7e83ea 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -564,6 +564,8 @@ def maybe_skip_test(self, spec): self.skipTest("CSOT not implemented for watch()") if "cursors" in class_name: self.skipTest("CSOT not implemented for cursors") + if "dropindex on collection" in description: + self.skipTest("PYTHON-5491") if ( "tailable" in class_name or "tailable" in description diff --git a/test/test_collection.py b/test/test_collection.py index ccace72bec..0dce88423b 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -333,6 +333,8 @@ def test_create_index(self): db.test.create_index(["hello", ("world", DESCENDING)]) db.test.create_index({"hello": 1}.items()) # type:ignore[arg-type] + # TODO: PYTHON-5491 - remove version max + @client_context.require_version_max(8, 0, -1) def test_drop_index(self): db = self.db db.test.drop_indexes() diff --git a/test/unified_format.py b/test/unified_format.py index c21f29fe19..3496b2ad44 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -563,6 +563,8 @@ def maybe_skip_test(self, spec): self.skipTest("CSOT not implemented for watch()") if "cursors" in class_name: self.skipTest("CSOT not implemented for cursors") + if "dropindex on collection" in description: + self.skipTest("PYTHON-5491") if ( "tailable" in class_name or "tailable" in description diff --git a/tools/synchro.py b/tools/synchro.py index e502f96281..9a760c0ad7 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -120,9 +120,9 @@ "_async_create_lock": "_create_lock", "_async_create_condition": "_create_condition", "_async_cond_wait": "_cond_wait", + "async_receive_kms": "receive_kms", "AsyncNetworkingInterface": "NetworkingInterface", 
"_configured_protocol_interface": "_configured_socket_interface", - "_async_configured_socket": "_configured_socket", "SpecRunnerTask": "SpecRunnerThread", "AsyncMockConnection": "MockConnection", "AsyncMockPool": "MockPool",