|
28 | 28 | Union,
|
29 | 29 | )
|
30 | 30 |
|
31 |
| -from pymongo import ssl_support |
| 31 | +from pymongo import _csot, ssl_support |
32 | 32 | from pymongo._asyncio_task import create_task
|
33 | 33 | from pymongo.errors import _OperationCancelled
|
34 | 34 | from pymongo.socket_checker import _errno_from_exception
|
@@ -316,47 +316,108 @@ async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLo
|
316 | 316 | return mv
|
317 | 317 |
|
318 | 318 |
|
319 |
| -def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview: |
320 |
| - buf = bytearray(length) |
321 |
| - mv = memoryview(buf) |
322 |
| - bytes_read = 0 |
323 |
| - # To support cancelling a network read, we shorten the socket timeout and |
324 |
| - # check for the cancellation signal after each timeout. Alternatively we |
325 |
| - # could close the socket but that does not reliably cancel recv() calls |
326 |
| - # on all OSes. |
327 |
| - orig_timeout = conn.conn.gettimeout() |
328 |
| - try: |
| 319 | +if "PyPy" not in sys.version: |
| 320 | + |
| 321 | + def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview: |
| 322 | + buf = bytearray(length) |
| 323 | + mv = memoryview(buf) |
| 324 | + bytes_read = 0 |
| 325 | + # To support cancelling a network read, we shorten the socket timeout and |
| 326 | + # check for the cancellation signal after each timeout. Alternatively we |
| 327 | + # could close the socket but that does not reliably cancel recv() calls |
| 328 | + # on all OSes. |
| 329 | + orig_timeout = conn.conn.gettimeout() |
| 330 | + try: |
| 331 | + while bytes_read < length: |
| 332 | + if deadline is not None: |
| 333 | + # CSOT: Update timeout. When the timeout has expired perform one |
| 334 | + # final non-blocking recv. This helps avoid spurious timeouts when |
| 335 | + # the response is actually already buffered on the client. |
| 336 | + short_timeout = min(max(deadline - time.monotonic(), 0), _POLL_TIMEOUT) |
| 337 | + else: |
| 338 | + short_timeout = _POLL_TIMEOUT |
| 339 | + conn.set_conn_timeout(short_timeout) |
| 340 | + try: |
| 341 | + chunk_length = conn.conn.recv_into(mv[bytes_read:]) |
| 342 | + except BLOCKING_IO_ERRORS: |
| 343 | + if conn.cancel_context.cancelled: |
| 344 | + raise _OperationCancelled("operation cancelled") from None |
| 345 | + # We reached the true deadline. |
| 346 | + raise socket.timeout("timed out") from None |
| 347 | + except socket.timeout: |
| 348 | + if conn.cancel_context.cancelled: |
| 349 | + raise _OperationCancelled("operation cancelled") from None |
| 350 | + continue |
| 351 | + except OSError as exc: |
| 352 | + if conn.cancel_context.cancelled: |
| 353 | + raise _OperationCancelled("operation cancelled") from None |
| 354 | + if _errno_from_exception(exc) == errno.EINTR: |
| 355 | + continue |
| 356 | + raise |
| 357 | + if chunk_length == 0: |
| 358 | + raise OSError("connection closed") |
| 359 | + |
| 360 | + bytes_read += chunk_length |
| 361 | + finally: |
| 362 | + conn.set_conn_timeout(orig_timeout) |
| 363 | + |
| 364 | + return mv |
| 365 | +else: |
| 366 | + |
| 367 | + def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: |
| 368 | + """Block until at least one byte is read, or a timeout, or a cancel.""" |
| 369 | + sock = conn.conn |
| 370 | + timed_out = False |
| 371 | + # Check if the connection's socket has been manually closed |
| 372 | + if sock.fileno() == -1: |
| 373 | + return |
| 374 | + while True: |
| 375 | + # SSLSocket can have buffered data which won't be caught by select. |
| 376 | + if hasattr(sock, "pending") and sock.pending() > 0: |
| 377 | + readable = True |
| 378 | + else: |
| 379 | + # Wait up to 500ms for the socket to become readable and then |
| 380 | + # check for cancellation. |
| 381 | + if deadline: |
| 382 | + remaining = deadline - time.monotonic() |
| 383 | + # When the timeout has expired perform one final check to |
| 384 | + # see if the socket is readable. This helps avoid spurious |
| 385 | + # timeouts on AWS Lambda and other FaaS environments. |
| 386 | + if remaining <= 0: |
| 387 | + timed_out = True |
| 388 | + timeout = max(min(remaining, _POLL_TIMEOUT), 0) |
| 389 | + else: |
| 390 | + timeout = _POLL_TIMEOUT |
| 391 | + readable = conn.socket_checker.select(sock, read=True, timeout=timeout) |
| 392 | + if conn.cancel_context.cancelled: |
| 393 | + raise _OperationCancelled("operation cancelled") |
| 394 | + if readable: |
| 395 | + return |
| 396 | + if timed_out: |
| 397 | + raise socket.timeout("timed out") |
| 398 | + |
| 399 | + def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview: |
| 400 | + buf = bytearray(length) |
| 401 | + mv = memoryview(buf) |
| 402 | + bytes_read = 0 |
329 | 403 | while bytes_read < length:
|
330 |
| - if deadline is not None: |
| 404 | + try: |
| 405 | + wait_for_read(conn, deadline) |
331 | 406 | # CSOT: Update timeout. When the timeout has expired perform one
|
332 | 407 | # final non-blocking recv. This helps avoid spurious timeouts when
|
333 | 408 | # the response is actually already buffered on the client.
|
334 |
| - short_timeout = min(max(deadline - time.monotonic(), 0), _POLL_TIMEOUT) |
335 |
| - else: |
336 |
| - short_timeout = _POLL_TIMEOUT |
337 |
| - conn.set_conn_timeout(short_timeout) |
338 |
| - try: |
| 409 | + if _csot.get_timeout() and deadline is not None: |
| 410 | + conn.set_conn_timeout(max(deadline - time.monotonic(), 0)) |
339 | 411 | chunk_length = conn.conn.recv_into(mv[bytes_read:])
|
340 | 412 | except BLOCKING_IO_ERRORS:
|
341 |
| - if conn.cancel_context.cancelled: |
342 |
| - raise _OperationCancelled("operation cancelled") from None |
343 |
| - # We reached the true deadline. |
344 | 413 | raise socket.timeout("timed out") from None
|
345 |
| - except socket.timeout: |
346 |
| - if conn.cancel_context.cancelled: |
347 |
| - raise _OperationCancelled("operation cancelled") from None |
348 |
| - continue |
349 | 414 | except OSError as exc:
|
350 |
| - if conn.cancel_context.cancelled: |
351 |
| - raise _OperationCancelled("operation cancelled") from None |
352 | 415 | if _errno_from_exception(exc) == errno.EINTR:
|
353 | 416 | continue
|
354 | 417 | raise
|
355 | 418 | if chunk_length == 0:
|
356 | 419 | raise OSError("connection closed")
|
357 | 420 |
|
358 | 421 | bytes_read += chunk_length
|
359 |
| - finally: |
360 |
| - conn.set_conn_timeout(orig_timeout) |
361 | 422 |
|
362 |
| - return mv |
| 423 | + return mv |
0 commit comments