Skip to content

Commit 482485d

Browse files
committed
Sync tests all passing
1 parent 574c0ec commit 482485d

File tree

14 files changed

+139
-494
lines changed

14 files changed

+139
-494
lines changed

pymongo/asynchronous/encryption.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@
6363
from pymongo.asynchronous.cursor import AsyncCursor
6464
from pymongo.asynchronous.database import AsyncDatabase
6565
from pymongo.asynchronous.mongo_client import AsyncMongoClient
66-
from pymongo.asynchronous.pool import _configured_socket, _raise_connection_failure
6766
from pymongo.common import CONNECT_TIMEOUT
6867
from pymongo.daemon import _spawn_daemon
6968
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts
@@ -75,12 +74,13 @@
7574
PyMongoError,
7675
ServerSelectionTimeoutError,
7776
)
78-
from pymongo.network_layer import BLOCKING_IO_ERRORS, async_sendall
77+
from pymongo.network_layer import async_sendall
7978
from pymongo.operations import UpdateOne
8079
from pymongo.pool_options import PoolOptions
80+
from pymongo.pool_shared import _configured_socket, _raise_connection_failure
8181
from pymongo.read_concern import ReadConcern
8282
from pymongo.results import BulkWriteResult, DeleteResult
83-
from pymongo.ssl_support import get_ssl_context
83+
from pymongo.ssl_support import BLOCKING_IO_ERRORS, get_ssl_context
8484
from pymongo.typings import _DocumentType, _DocumentTypeArg
8585
from pymongo.uri_parser import parse_host
8686
from pymongo.write_concern import WriteConcern

pymongo/asynchronous/network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ async def command(
189189
)
190190

191191
try:
192-
await async_sendall(conn.conn.writer, msg)
192+
await async_sendall(conn.conn.get_conn, msg)
193193
if use_op_msg and unacknowledged:
194194
# Unacknowledged, fake a successful command response.
195195
reply = None

pymongo/asynchronous/pool.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,8 @@
1717
import asyncio
1818
import collections
1919
import contextlib
20-
import functools
2120
import logging
2221
import os
23-
import socket
24-
import ssl
2522
import sys
2623
import time
2724
import weakref
@@ -52,16 +49,13 @@
5249
from pymongo.errors import ( # type:ignore[attr-defined]
5350
AutoReconnect,
5451
ConfigurationError,
55-
ConnectionFailure,
5652
DocumentTooLarge,
5753
ExecutionTimeout,
5854
InvalidOperation,
59-
NetworkTimeout,
6055
NotPrimaryError,
6156
OperationFailure,
6257
PyMongoError,
6358
WaitQueueTimeoutError,
64-
_CertificateError,
6559
)
6660
from pymongo.hello import Hello, HelloCompat
6761
from pymongo.lock import (
@@ -79,10 +73,15 @@
7973
ConnectionCheckOutFailedReason,
8074
ConnectionClosedReason,
8175
)
82-
from pymongo.network_layer import async_receive_message, async_sendall, AsyncNetworkingInterface
76+
from pymongo.network_layer import AsyncNetworkingInterface, async_receive_message, async_sendall
8377
from pymongo.pool_options import PoolOptions
84-
from pymongo.pool_shared import _configured_protocol, _CancellationContext, _get_timeout_details, format_timeout_details, \
85-
_raise_connection_failure
78+
from pymongo.pool_shared import (
79+
_CancellationContext,
80+
_configured_protocol,
81+
_get_timeout_details,
82+
_raise_connection_failure,
83+
format_timeout_details,
84+
)
8685
from pymongo.read_preferences import ReadPreference
8786
from pymongo.server_api import _add_to_command
8887
from pymongo.server_type import SERVER_TYPE
@@ -101,7 +100,6 @@
101100
ZstdContext,
102101
)
103102
from pymongo.message import _OpMsg, _OpReply
104-
from pymongo.pyopenssl_context import _sslConn
105103
from pymongo.read_concern import ReadConcern
106104
from pymongo.read_preferences import _ServerMode
107105
from pymongo.typings import ClusterTime, _Address, _CollationIn
@@ -195,6 +193,7 @@ def set_conn_timeout(self, timeout: Optional[float]) -> None:
195193
if timeout == self.last_timeout:
196194
return
197195
self.last_timeout = timeout
196+
self.conn.get_conn.settimeout(timeout)
198197

199198
def apply_timeout(
200199
self, client: AsyncMongoClient, cmd: Optional[MutableMapping[str, Any]]
@@ -453,7 +452,7 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None:
453452
)
454453

455454
try:
456-
await async_sendall(self.conn.writer, message)
455+
await async_sendall(self.conn.get_conn, message)
457456
except BaseException as error:
458457
self._raise_connection_failure(error)
459458

@@ -589,7 +588,10 @@ def _close_conn(self) -> None:
589588

590589
def conn_closed(self) -> bool:
591590
"""Return True if we know socket has been closed, False otherwise."""
592-
return self.conn.is_closing()
591+
if _IS_SYNC:
592+
return self.socket_checker.socket_closed(self.conn.get_conn)
593+
else:
594+
return self.conn.is_closing()
593595

594596
def send_cluster_time(
595597
self,
@@ -977,9 +979,7 @@ async def remove_stale_sockets(self, reference_generation: int) -> None:
977979
self.requests -= 1
978980
self.size_cond.notify()
979981

980-
async def connect(
981-
self, handler: Optional[_MongoClientErrorHandler] = None
982-
) -> AsyncConnection:
982+
async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> AsyncConnection:
983983
"""Connect to Mongo and return a new AsyncConnection.
984984
985985
Can raise ConnectionFailure.

pymongo/network_layer.py

Lines changed: 44 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
Union,
2727
)
2828

29-
from pymongo import _csot
29+
from pymongo import _csot, ssl_support
3030
from pymongo._asyncio_task import create_task
3131
from pymongo.common import MAX_MESSAGE_SIZE
3232
from pymongo.compression_support import decompress
@@ -59,7 +59,9 @@
5959

6060

6161
class NetworkingInterfaceBase:
62-
def __init__(self, conn: Union[socket.socket, _sslConn] | tuple[asyncio.BaseTransport, PyMongoProtocol]):
62+
def __init__(
63+
self, conn: Union[socket.socket, _sslConn] | tuple[asyncio.BaseTransport, PyMongoProtocol]
64+
):
6365
self.conn = conn
6466

6567
def gettimeout(self):
@@ -74,10 +76,7 @@ def close(self):
7476
def is_closing(self) -> bool:
7577
raise NotImplementedError
7678

77-
def writer(self):
78-
raise NotImplementedError
79-
80-
def reader(self):
79+
def get_conn(self):
8180
raise NotImplementedError
8281

8382

@@ -99,11 +98,7 @@ def is_closing(self):
9998
self.conn[0].is_closing()
10099

101100
@property
102-
def writer(self) -> PyMongoProtocol:
103-
return self.conn[1]
104-
105-
@property
106-
def reader(self) -> PyMongoProtocol:
101+
def get_conn(self) -> PyMongoProtocol:
107102
return self.conn[1]
108103

109104

@@ -124,11 +119,7 @@ def is_closing(self):
124119
self.conn.is_closing()
125120

126121
@property
127-
def writer(self):
128-
return self.conn
129-
130-
@property
131-
def reader(self):
122+
def get_conn(self):
132123
return self.conn
133124

134125

@@ -333,35 +324,8 @@ async def _poll_cancellation(conn: AsyncConnection) -> None:
333324
await asyncio.sleep(_POLL_TIMEOUT)
334325

335326

336-
async def async_receive_data(
337-
conn: AsyncConnection,
338-
deadline: Optional[float],
339-
request_id: Optional[int],
340-
max_message_size: int,
341-
) -> memoryview:
342-
conn_timeout = conn.conn.gettimeout
343-
timeout: Optional[Union[float, int]]
344-
if deadline:
345-
# When the timeout has expired perform one final check to
346-
# see if the socket is readable. This helps avoid spurious
347-
# timeouts on AWS Lambda and other FaaS environments.
348-
timeout = max(deadline - time.monotonic(), 0)
349-
else:
350-
timeout = conn_timeout
351-
352-
cancellation_task = create_task(_poll_cancellation(conn))
353-
read_task = create_task(conn.conn.reader.read(request_id, max_message_size))
354-
tasks = [read_task, cancellation_task]
355-
done, pending = await asyncio.wait(tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED)
356-
for task in pending:
357-
task.cancel()
358-
if pending:
359-
await asyncio.wait(pending)
360-
if len(done) == 0:
361-
raise socket.timeout("timed out")
362-
if read_task in done:
363-
return read_task.result()
364-
raise _OperationCancelled("operation cancelled")
327+
# Errors raised by sockets (and TLS sockets) when in non-blocking mode.
328+
BLOCKING_IO_ERRORS = (BlockingIOError, *ssl_support.BLOCKING_IO_ERRORS)
365329

366330

367331
def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
@@ -384,12 +348,12 @@ def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> me
384348
short_timeout = _POLL_TIMEOUT
385349
conn.set_conn_timeout(short_timeout)
386350
try:
387-
chunk_length = conn.conn.recv_into(mv[bytes_read:])
388-
# except BLOCKING_IO_ERRORS:
389-
# if conn.cancel_context.cancelled:
390-
# raise _OperationCancelled("operation cancelled") from None
391-
# # We reached the true deadline.
392-
# raise socket.timeout("timed out") from None
351+
chunk_length = conn.conn.get_conn.recv_into(mv[bytes_read:])
352+
except BLOCKING_IO_ERRORS:
353+
if conn.cancel_context.cancelled:
354+
raise _OperationCancelled("operation cancelled") from None
355+
# We reached the true deadline.
356+
raise socket.timeout("timed out") from None
393357
except socket.timeout:
394358
if conn.cancel_context.cancelled:
395359
raise _OperationCancelled("operation cancelled") from None
@@ -416,22 +380,42 @@ async def async_receive_message(
416380
max_message_size: int = MAX_MESSAGE_SIZE,
417381
) -> Union[_OpReply, _OpMsg]:
418382
"""Receive a raw BSON message or raise socket.error."""
383+
timeout: Optional[Union[float, int]]
419384
if _csot.get_timeout():
420385
deadline = _csot.get_deadline()
421386
else:
422-
timeout = conn.conn.reader.gettimeout
387+
timeout = conn.conn.get_conn.gettimeout
423388
if timeout:
424389
deadline = time.monotonic() + timeout
425390
else:
426391
deadline = None
427-
data, op_code = await async_receive_data(conn, deadline, request_id, max_message_size)
428-
try:
429-
unpack_reply = _UNPACK_REPLY[op_code]
430-
except KeyError:
431-
raise ProtocolError(
432-
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
433-
) from None
434-
return unpack_reply(data)
392+
if deadline:
393+
# When the timeout has expired perform one final check to
394+
# see if the socket is readable. This helps avoid spurious
395+
# timeouts on AWS Lambda and other FaaS environments.
396+
timeout = max(deadline - time.monotonic(), 0)
397+
398+
cancellation_task = create_task(_poll_cancellation(conn))
399+
read_task = create_task(conn.conn.get_conn.read(request_id, max_message_size))
400+
tasks = [read_task, cancellation_task]
401+
done, pending = await asyncio.wait(tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED)
402+
for task in pending:
403+
task.cancel()
404+
if pending:
405+
await asyncio.wait(pending)
406+
if len(done) == 0:
407+
raise socket.timeout("timed out")
408+
if read_task in done:
409+
data, op_code = read_task.result()
410+
411+
try:
412+
unpack_reply = _UNPACK_REPLY[op_code]
413+
except KeyError:
414+
raise ProtocolError(
415+
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
416+
) from None
417+
return unpack_reply(data)
418+
raise _OperationCancelled("operation cancelled")
435419

436420

437421
def receive_message(

pymongo/pool_shared.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,30 +9,20 @@
99
Any,
1010
NoReturn,
1111
Optional,
12-
Union,
1312
)
1413

1514
from pymongo import _csot
1615
from pymongo.errors import ( # type:ignore[attr-defined]
1716
AutoReconnect,
18-
ConfigurationError,
1917
ConnectionFailure,
20-
DocumentTooLarge,
21-
ExecutionTimeout,
22-
InvalidOperation,
2318
NetworkTimeout,
24-
NotPrimaryError,
25-
OperationFailure,
26-
PyMongoError,
27-
WaitQueueTimeoutError,
2819
_CertificateError,
2920
)
30-
from pymongo.network_layer import PyMongoProtocol, AsyncNetworkingInterface
21+
from pymongo.network_layer import AsyncNetworkingInterface, NetworkingInterface, PyMongoProtocol
3122
from pymongo.pool_options import PoolOptions
3223
from pymongo.ssl_support import HAS_SNI, SSLError
3324

3425
if TYPE_CHECKING:
35-
from pymongo.pyopenssl_context import _sslConn
3626
from pymongo.typings import _Address
3727

3828
try:
@@ -50,6 +40,7 @@ def _set_non_inheritable_non_atomic(fd: int) -> None:
5040
def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001
5141
"""Dummy function for platforms that don't provide fcntl."""
5242

43+
5344
_MAX_TCP_KEEPIDLE = 120
5445
_MAX_TCP_KEEPINTVL = 10
5546
_MAX_TCP_KEEPCNT = 9
@@ -249,9 +240,7 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket
249240
raise OSError("getaddrinfo failed")
250241

251242

252-
async def _configured_protocol(
253-
address: _Address, options: PoolOptions
254-
) -> AsyncNetworkingInterface:
243+
async def _configured_protocol(address: _Address, options: PoolOptions) -> AsyncNetworkingInterface:
255244
"""Given (host, port) and PoolOptions, return a configured transport, protocol pair.
256245
257246
Can raise socket.error, ConnectionFailure, or _CertificateError.
@@ -263,9 +252,11 @@ async def _configured_protocol(
263252
timeout = sock.gettimeout()
264253

265254
if ssl_context is None:
266-
return AsyncNetworkingInterface(await asyncio.get_running_loop().create_connection(
267-
lambda: PyMongoProtocol(timeout=timeout, buffer_size=2**16), sock=sock
268-
))
255+
return AsyncNetworkingInterface(
256+
await asyncio.get_running_loop().create_connection(
257+
lambda: PyMongoProtocol(timeout=timeout, buffer_size=2**16), sock=sock
258+
)
259+
)
269260

270261
host = address[0]
271262
try:
@@ -303,9 +294,7 @@ async def _configured_protocol(
303294
return AsyncNetworkingInterface((transport, protocol))
304295

305296

306-
def _configured_socket(
307-
address: _Address, options: PoolOptions
308-
) -> Union[socket.socket, _sslConn]:
297+
def _configured_socket(address: _Address, options: PoolOptions) -> NetworkingInterface:
309298
"""Given (host, port) and PoolOptions, return a configured socket.
310299
311300
Can raise socket.error, ConnectionFailure, or _CertificateError.
@@ -317,7 +306,7 @@ def _configured_socket(
317306

318307
if ssl_context is None:
319308
sock.settimeout(options.socket_timeout)
320-
return sock
309+
return NetworkingInterface(sock)
321310

322311
host = address[0]
323312
try:
@@ -351,4 +340,4 @@ def _configured_socket(
351340
raise
352341

353342
ssl_sock.settimeout(options.socket_timeout)
354-
return ssl_sock
343+
return NetworkingInterface(ssl_sock)

0 commit comments

Comments
 (0)