Skip to content

Commit cb1031b

Browse files
committed
Merge branch 'master' of github.com:mongodb/mongo-python-driver into PYTHON-5143
2 parents d2f141d + e4b7eb5 commit cb1031b

File tree

12 files changed

+471
-617
lines changed

12 files changed

+471
-617
lines changed

pymongo/asynchronous/encryption.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
from pymongo.asynchronous.cursor import AsyncCursor
6565
from pymongo.asynchronous.database import AsyncDatabase
6666
from pymongo.asynchronous.mongo_client import AsyncMongoClient
67+
from pymongo.asynchronous.pool import AsyncBaseConnection
6768
from pymongo.common import CONNECT_TIMEOUT
6869
from pymongo.daemon import _spawn_daemon
6970
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts
@@ -75,12 +76,12 @@
7576
NetworkTimeout,
7677
ServerSelectionTimeoutError,
7778
)
78-
from pymongo.network_layer import async_socket_sendall
79+
from pymongo.helpers_shared import _get_timeout_details
80+
from pymongo.network_layer import PyMongoKMSProtocol, async_receive_kms, async_sendall
7981
from pymongo.operations import UpdateOne
8082
from pymongo.pool_options import PoolOptions
8183
from pymongo.pool_shared import (
82-
_async_configured_socket,
83-
_get_timeout_details,
84+
_configured_protocol_interface,
8485
_raise_connection_failure,
8586
)
8687
from pymongo.read_concern import ReadConcern
@@ -93,10 +94,8 @@
9394
if TYPE_CHECKING:
9495
from pymongocrypt.mongocrypt import MongoCryptKmsContext
9596

96-
from pymongo.pyopenssl_context import _sslConn
9797
from pymongo.typings import _Address
9898

99-
10099
_IS_SYNC = False
101100

102101
_HTTPS_PORT = 443
@@ -111,9 +110,10 @@
111110
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
112111

113112

114-
async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
113+
async def _connect_kms(address: _Address, opts: PoolOptions) -> AsyncBaseConnection:
115114
try:
116-
return await _async_configured_socket(address, opts)
115+
interface = await _configured_protocol_interface(address, opts, PyMongoKMSProtocol)
116+
return AsyncBaseConnection(interface, opts)
117117
except Exception as exc:
118118
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
119119

@@ -198,18 +198,11 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
198198
try:
199199
conn = await _connect_kms(address, opts)
200200
try:
201-
await async_socket_sendall(conn, message)
201+
await async_sendall(conn.conn.get_conn, message)
202202
while kms_context.bytes_needed > 0:
203203
# CSOT: update timeout.
204-
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
205-
if _IS_SYNC:
206-
data = conn.recv(kms_context.bytes_needed)
207-
else:
208-
from pymongo.network_layer import ( # type: ignore[attr-defined]
209-
async_receive_data_socket,
210-
)
211-
212-
data = await async_receive_data_socket(conn, kms_context.bytes_needed)
204+
conn.set_conn_timeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
205+
data = await async_receive_kms(conn, kms_context.bytes_needed)
213206
if not data:
214207
raise OSError("KMS connection closed")
215208
kms_context.feed(data)
@@ -228,7 +221,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
228221
address, exc, msg_prefix=msg_prefix, timeout_details=_get_timeout_details(opts)
229222
)
230223
finally:
231-
conn.close()
224+
await conn.close_conn(None)
232225
except MongoCryptError:
233226
raise # Propagate MongoCryptError errors directly.
234227
except Exception as exc:

pymongo/asynchronous/pool.py

Lines changed: 91 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
WaitQueueTimeoutError,
5959
)
6060
from pymongo.hello import Hello, HelloCompat
61+
from pymongo.helpers_shared import _get_timeout_details, format_timeout_details
6162
from pymongo.lock import (
6263
_async_cond_wait,
6364
_async_create_condition,
@@ -79,9 +80,7 @@
7980
SSLErrors,
8081
_CancellationContext,
8182
_configured_protocol_interface,
82-
_get_timeout_details,
8383
_raise_connection_failure,
84-
format_timeout_details,
8584
)
8685
from pymongo.read_preferences import ReadPreference
8786
from pymongo.server_api import _add_to_command
@@ -124,7 +123,89 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001
124123
_IS_SYNC = False
125124

126125

127-
class AsyncConnection:
126+
class AsyncBaseConnection:
127+
"""A base connection object for server and kms connections."""
128+
129+
def __init__(self, conn: AsyncNetworkingInterface, opts: PoolOptions):
130+
self.conn = conn
131+
self.socket_checker: SocketChecker = SocketChecker()
132+
self.cancel_context: _CancellationContext = _CancellationContext()
133+
self.is_sdam = False
134+
self.closed = False
135+
self.last_timeout: float | None = None
136+
self.more_to_come = False
137+
self.opts = opts
138+
self.max_wire_version = -1
139+
140+
def set_conn_timeout(self, timeout: Optional[float]) -> None:
141+
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
142+
if timeout == self.last_timeout:
143+
return
144+
self.last_timeout = timeout
145+
self.conn.get_conn.settimeout(timeout)
146+
147+
def apply_timeout(
148+
self, client: AsyncMongoClient[Any], cmd: Optional[MutableMapping[str, Any]]
149+
) -> Optional[float]:
150+
# CSOT: use remaining timeout when set.
151+
timeout = _csot.remaining()
152+
if timeout is None:
153+
# Reset the socket timeout unless we're performing a streaming monitor check.
154+
if not self.more_to_come:
155+
self.set_conn_timeout(self.opts.socket_timeout)
156+
return None
157+
# RTT validation.
158+
rtt = _csot.get_rtt()
159+
if rtt is None:
160+
rtt = self.connect_rtt
161+
max_time_ms = timeout - rtt
162+
if max_time_ms < 0:
163+
timeout_details = _get_timeout_details(self.opts)
164+
formatted = format_timeout_details(timeout_details)
165+
# CSOT: raise an error without running the command since we know it will time out.
166+
errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}"
167+
if self.max_wire_version != -1:
168+
raise ExecutionTimeout(
169+
errmsg,
170+
50,
171+
{"ok": 0, "errmsg": errmsg, "code": 50},
172+
self.max_wire_version,
173+
)
174+
else:
175+
raise TimeoutError(errmsg)
176+
if cmd is not None:
177+
cmd["maxTimeMS"] = int(max_time_ms * 1000)
178+
self.set_conn_timeout(timeout)
179+
return timeout
180+
181+
async def close_conn(self, reason: Optional[str]) -> None:
182+
"""Close this connection with a reason."""
183+
if self.closed:
184+
return
185+
await self._close_conn()
186+
187+
async def _close_conn(self) -> None:
188+
"""Close this connection."""
189+
if self.closed:
190+
return
191+
self.closed = True
192+
self.cancel_context.cancel()
193+
# Note: We catch exceptions to avoid spurious errors on interpreter
194+
# shutdown.
195+
try:
196+
await self.conn.close()
197+
except Exception: # noqa: S110
198+
pass
199+
200+
def conn_closed(self) -> bool:
201+
"""Return True if we know socket has been closed, False otherwise."""
202+
if _IS_SYNC:
203+
return self.socket_checker.socket_closed(self.conn.get_conn)
204+
else:
205+
return self.conn.is_closing()
206+
207+
208+
class AsyncConnection(AsyncBaseConnection):
128209
"""Store a connection with some metadata.
129210
130211
:param conn: a raw connection object
@@ -142,29 +223,27 @@ def __init__(
142223
id: int,
143224
is_sdam: bool,
144225
):
226+
super().__init__(conn, pool.opts)
145227
self.pool_ref = weakref.ref(pool)
146-
self.conn = conn
147-
self.address = address
148-
self.id = id
228+
self.address: tuple[str, int] = address
229+
self.id: int = id
149230
self.is_sdam = is_sdam
150-
self.closed = False
151231
self.last_checkin_time = time.monotonic()
152232
self.performed_handshake = False
153233
self.is_writable: bool = False
154234
self.max_wire_version = MAX_WIRE_VERSION
155-
self.max_bson_size = MAX_BSON_SIZE
156-
self.max_message_size = MAX_MESSAGE_SIZE
157-
self.max_write_batch_size = MAX_WRITE_BATCH_SIZE
235+
self.max_bson_size: int = MAX_BSON_SIZE
236+
self.max_message_size: int = MAX_MESSAGE_SIZE
237+
self.max_write_batch_size: int = MAX_WRITE_BATCH_SIZE
158238
self.supports_sessions = False
159239
self.hello_ok: bool = False
160-
self.is_mongos = False
240+
self.is_mongos: bool = False
161241
self.op_msg_enabled = False
162242
self.listeners = pool.opts._event_listeners
163243
self.enabled_for_cmap = pool.enabled_for_cmap
164244
self.enabled_for_logging = pool.enabled_for_logging
165245
self.compression_settings = pool.opts._compression_settings
166246
self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None
167-
self.socket_checker: SocketChecker = SocketChecker()
168247
self.oidc_token_gen_id: Optional[int] = None
169248
# Support for mechanism negotiation on the initial handshake.
170249
self.negotiated_mechs: Optional[list[str]] = None
@@ -175,9 +254,6 @@ def __init__(
175254
self.pool_gen = pool.gen
176255
self.generation = self.pool_gen.get_overall()
177256
self.ready = False
178-
self.cancel_context: _CancellationContext = _CancellationContext()
179-
self.opts = pool.opts
180-
self.more_to_come: bool = False
181257
# For load balancer support.
182258
self.service_id: Optional[ObjectId] = None
183259
self.server_connection_id: Optional[int] = None
@@ -193,44 +269,6 @@ def __init__(
193269
# For gossiping $clusterTime from the connection handshake to the client.
194270
self._cluster_time = None
195271

196-
def set_conn_timeout(self, timeout: Optional[float]) -> None:
197-
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
198-
if timeout == self.last_timeout:
199-
return
200-
self.last_timeout = timeout
201-
self.conn.get_conn.settimeout(timeout)
202-
203-
def apply_timeout(
204-
self, client: AsyncMongoClient[Any], cmd: Optional[MutableMapping[str, Any]]
205-
) -> Optional[float]:
206-
# CSOT: use remaining timeout when set.
207-
timeout = _csot.remaining()
208-
if timeout is None:
209-
# Reset the socket timeout unless we're performing a streaming monitor check.
210-
if not self.more_to_come:
211-
self.set_conn_timeout(self.opts.socket_timeout)
212-
return None
213-
# RTT validation.
214-
rtt = _csot.get_rtt()
215-
if rtt is None:
216-
rtt = self.connect_rtt
217-
max_time_ms = timeout - rtt
218-
if max_time_ms < 0:
219-
timeout_details = _get_timeout_details(self.opts)
220-
formatted = format_timeout_details(timeout_details)
221-
# CSOT: raise an error without running the command since we know it will time out.
222-
errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}"
223-
raise ExecutionTimeout(
224-
errmsg,
225-
50,
226-
{"ok": 0, "errmsg": errmsg, "code": 50},
227-
self.max_wire_version,
228-
)
229-
if cmd is not None:
230-
cmd["maxTimeMS"] = int(max_time_ms * 1000)
231-
self.set_conn_timeout(timeout)
232-
return timeout
233-
234272
def pin_txn(self) -> None:
235273
self.pinned_txn = True
236274
assert not self.pinned_cursor
@@ -574,26 +612,6 @@ async def close_conn(self, reason: Optional[str]) -> None:
574612
error=reason,
575613
)
576614

577-
async def _close_conn(self) -> None:
578-
"""Close this connection."""
579-
if self.closed:
580-
return
581-
self.closed = True
582-
self.cancel_context.cancel()
583-
# Note: We catch exceptions to avoid spurious errors on interpreter
584-
# shutdown.
585-
try:
586-
await self.conn.close()
587-
except Exception: # noqa: S110
588-
pass
589-
590-
def conn_closed(self) -> bool:
591-
"""Return True if we know socket has been closed, False otherwise."""
592-
if _IS_SYNC:
593-
return self.socket_checker.socket_closed(self.conn.get_conn)
594-
else:
595-
return self.conn.is_closing()
596-
597615
def send_cluster_time(
598616
self,
599617
command: MutableMapping[str, Any],

pymongo/asynchronous/server.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
_SDAMStatusMessage,
3939
)
4040
from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query
41-
from pymongo.pool_shared import _get_timeout_details, format_timeout_details
4241
from pymongo.response import PinnedResponse, Response
4342

4443
if TYPE_CHECKING:
@@ -225,11 +224,7 @@ async def run_operation(
225224
if use_cmd:
226225
first = docs[0]
227226
await operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type]
228-
# Append timeout details to MaxTimeMSExpired responses.
229-
if first.get("code") == 50:
230-
timeout_details = _get_timeout_details(conn.opts) # type:ignore[has-type]
231-
first["errmsg"] += format_timeout_details(timeout_details) # type:ignore[index]
232-
_check_command_response(first, conn.max_wire_version)
227+
_check_command_response(first, conn.max_wire_version, pool_opts=conn.opts) # type:ignore[has-type]
233228
except Exception as exc:
234229
duration = datetime.now() - start
235230
if isinstance(exc, (NotPrimaryError, OperationFailure)):

pymongo/helpers_shared.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
if TYPE_CHECKING:
4848
from pymongo.cursor_shared import _Hint
4949
from pymongo.operations import _IndexList
50+
from pymongo.pool_options import PoolOptions
5051
from pymongo.typings import _DocumentOut
5152

5253

@@ -108,6 +109,34 @@
108109
}
109110

110111

112+
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
113+
from pymongo import _csot
114+
115+
details = {}
116+
timeout = _csot.get_timeout()
117+
socket_timeout = options.socket_timeout
118+
connect_timeout = options.connect_timeout
119+
if timeout:
120+
details["timeoutMS"] = timeout * 1000
121+
if socket_timeout and not timeout:
122+
details["socketTimeoutMS"] = socket_timeout * 1000
123+
if connect_timeout:
124+
details["connectTimeoutMS"] = connect_timeout * 1000
125+
return details
126+
127+
128+
def format_timeout_details(details: Optional[dict[str, float]]) -> str:
129+
result = ""
130+
if details:
131+
result += " (configured timeouts:"
132+
for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]:
133+
if timeout in details:
134+
result += f" {timeout}: {details[timeout]}ms,"
135+
result = result[:-1]
136+
result += ")"
137+
return result
138+
139+
111140
def _gen_index_name(keys: _IndexList) -> str:
112141
"""Generate an index name from the set of fields it is over."""
113142
return "_".join(["{}_{}".format(*item) for item in keys])
@@ -188,6 +217,7 @@ def _check_command_response(
188217
max_wire_version: Optional[int],
189218
allowable_errors: Optional[Container[Union[int, str]]] = None,
190219
parse_write_concern_error: bool = False,
220+
pool_opts: Optional[PoolOptions] = None,
191221
) -> None:
192222
"""Check the response to a command for errors."""
193223
if "ok" not in response:
@@ -243,6 +273,10 @@ def _check_command_response(
243273
if code in (11000, 11001, 12582):
244274
raise DuplicateKeyError(errmsg, code, response, max_wire_version)
245275
elif code == 50:
276+
# Append timeout details to MaxTimeMSExpired responses.
277+
if pool_opts:
278+
timeout_details = _get_timeout_details(pool_opts)
279+
errmsg += format_timeout_details(timeout_details)
246280
raise ExecutionTimeout(errmsg, code, response, max_wire_version)
247281
elif code == 43:
248282
raise CursorNotFound(errmsg, code, response, max_wire_version)

0 commit comments

Comments
 (0)