diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py
index beffba6d18..4d21300bc6 100644
--- a/pymongo/network_layer.py
+++ b/pymongo/network_layer.py
@@ -28,7 +28,7 @@
     Union,
 )
 
-from pymongo import ssl_support
+from pymongo import _csot, ssl_support
 from pymongo._asyncio_task import create_task
 from pymongo.errors import _OperationCancelled
 from pymongo.socket_checker import _errno_from_exception
@@ -316,6 +316,42 @@ async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLo
     return mv
 
 
+_PYPY = "PyPy" in sys.version
+
+
+def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
+    """Block until at least one byte is read, or a timeout, or a cancel."""
+    sock = conn.conn
+    timed_out = False
+    # Check if the connection's socket has been manually closed
+    if sock.fileno() == -1:
+        return
+    while True:
+        # SSLSocket can have buffered data which won't be caught by select.
+        if hasattr(sock, "pending") and sock.pending() > 0:
+            readable = True
+        else:
+            # Wait up to 500ms for the socket to become readable and then
+            # check for cancellation.
+            if deadline:
+                remaining = deadline - time.monotonic()
+                # When the timeout has expired perform one final check to
+                # see if the socket is readable. This helps avoid spurious
+                # timeouts on AWS Lambda and other FaaS environments.
+                if remaining <= 0:
+                    timed_out = True
+                timeout = max(min(remaining, _POLL_TIMEOUT), 0)
+            else:
+                timeout = _POLL_TIMEOUT
+            readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
+        if conn.cancel_context.cancelled:
+            raise _OperationCancelled("operation cancelled")
+        if readable:
+            return
+        if timed_out:
+            raise socket.timeout("timed out")
+
+
 def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
     buf = bytearray(length)
     mv = memoryview(buf)
@@ -324,18 +360,25 @@ def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> me
     # check for the cancellation signal after each timeout. Alternatively we
     # could close the socket but that does not reliably cancel recv() calls
     # on all OSes.
+    # When the timeout has expired we perform one final non-blocking recv.
+    # This helps avoid spurious timeouts when the response is actually already
+    # buffered on the client.
     orig_timeout = conn.conn.gettimeout()
     try:
         while bytes_read < length:
-            if deadline is not None:
-                # CSOT: Update timeout. When the timeout has expired perform one
-                # final non-blocking recv. This helps avoid spurious timeouts when
-                # the response is actually already buffered on the client.
-                short_timeout = min(max(deadline - time.monotonic(), 0), _POLL_TIMEOUT)
-            else:
-                short_timeout = _POLL_TIMEOUT
-            conn.set_conn_timeout(short_timeout)
             try:
+                # Use the legacy wait_for_read cancellation approach on PyPy due to PYTHON-5011.
+                if _PYPY:
+                    wait_for_read(conn, deadline)
+                    if _csot.get_timeout() and deadline is not None:
+                        conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
+                else:
+                    if deadline is not None:
+                        short_timeout = min(max(deadline - time.monotonic(), 0), _POLL_TIMEOUT)
+                    else:
+                        short_timeout = _POLL_TIMEOUT
+                    conn.set_conn_timeout(short_timeout)
+
                 chunk_length = conn.conn.recv_into(mv[bytes_read:])
             except BLOCKING_IO_ERRORS:
                 if conn.cancel_context.cancelled:
@@ -345,6 +388,9 @@ def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> me
             except socket.timeout:
                 if conn.cancel_context.cancelled:
                     raise _OperationCancelled("operation cancelled") from None
+                if _PYPY:
+                    # We reached the true deadline.
+                    raise
                 continue
             except OSError as exc:
                 if conn.cancel_context.cancelled: