Skip to content

Commit a490cfd

Browse files
committed
PYTHON-5215 Add an asyncio.Protocol implementation for KMS
1 parent d7074ba commit a490cfd

File tree

6 files changed

+411
-471
lines changed

6 files changed

+411
-471
lines changed

pymongo/asynchronous/encryption.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
from pymongo.asynchronous.cursor import AsyncCursor
6565
from pymongo.asynchronous.database import AsyncDatabase
6666
from pymongo.asynchronous.mongo_client import AsyncMongoClient
67+
from pymongo.asynchronous.pool import AsyncBaseConnection
6768
from pymongo.common import CONNECT_TIMEOUT
6869
from pymongo.daemon import _spawn_daemon
6970
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts
@@ -75,11 +76,12 @@
7576
NetworkTimeout,
7677
ServerSelectionTimeoutError,
7778
)
78-
from pymongo.network_layer import async_socket_sendall
79+
from pymongo.network_layer import async_receive_kms, async_sendall
7980
from pymongo.operations import UpdateOne
8081
from pymongo.pool_options import PoolOptions
8182
from pymongo.pool_shared import (
8283
_async_configured_socket,
84+
_configured_protocol_interface,
8385
_get_timeout_details,
8486
_raise_connection_failure,
8587
)
@@ -196,23 +198,16 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
196198
sleep_sec = float(sleep_u) / 1e6
197199
await asyncio.sleep(sleep_sec)
198200
try:
199-
conn = await _connect_kms(address, opts)
201+
interface = await _configured_protocol_interface(address, opts)
202+
conn = AsyncBaseConnection(interface, opts)
203+
# Given a conn object, we want to send a message and then receive the bytes needed
200204
try:
201-
await async_socket_sendall(conn, message)
202-
while kms_context.bytes_needed > 0:
203-
# CSOT: update timeout.
204-
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
205-
if _IS_SYNC:
206-
data = conn.recv(kms_context.bytes_needed)
207-
else:
208-
from pymongo.network_layer import ( # type: ignore[attr-defined]
209-
async_receive_data_socket,
210-
)
211-
212-
data = await async_receive_data_socket(conn, kms_context.bytes_needed)
213-
if not data:
214-
raise OSError("KMS connection closed")
215-
kms_context.feed(data)
205+
await async_sendall(interface.get_conn, message)
206+
interface.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
207+
data = await async_receive_kms(conn, kms_context.bytes_needed)
208+
if not data:
209+
raise OSError("KMS connection closed")
210+
kms_context.feed(bytes)
216211
except MongoCryptError:
217212
raise # Propagate MongoCryptError errors directly.
218213
except Exception as exc:
@@ -228,7 +223,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
228223
address, exc, msg_prefix=msg_prefix, timeout_details=_get_timeout_details(opts)
229224
)
230225
finally:
231-
conn.close()
226+
interface.get_conn.close()
232227
except MongoCryptError:
233228
raise # Propagate MongoCryptError errors directly.
234229
except Exception as exc:

pymongo/asynchronous/pool.py

Lines changed: 91 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,89 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001
124124
_IS_SYNC = False
125125

126126

127-
class AsyncConnection:
127+
class AsyncBaseConnection:
128+
"""A base connection object for server and kms connections."""
129+
130+
def __init__(self, conn: AsyncNetworkingInterface, opts: PoolOptions):
131+
self.conn = conn
132+
self.socket_checker: SocketChecker = SocketChecker()
133+
self.cancel_context: _CancellationContext = _CancellationContext()
134+
self.is_sdam = False
135+
self.closed = False
136+
self.last_timeout: float | None = None
137+
self.more_to_come = False
138+
self.opts = opts
139+
self.max_wire_version = -1
140+
141+
def set_conn_timeout(self, timeout: Optional[float]) -> None:
142+
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
143+
if timeout == self.last_timeout:
144+
return
145+
self.last_timeout = timeout
146+
self.conn.get_conn.settimeout(timeout)
147+
148+
def apply_timeout(
149+
self, client: AsyncMongoClient, cmd: Optional[MutableMapping[str, Any]]
150+
) -> Optional[float]:
151+
# CSOT: use remaining timeout when set.
152+
timeout = _csot.remaining()
153+
if timeout is None:
154+
# Reset the socket timeout unless we're performing a streaming monitor check.
155+
if not self.more_to_come:
156+
self.set_conn_timeout(self.opts.socket_timeout)
157+
return None
158+
# RTT validation.
159+
rtt = _csot.get_rtt()
160+
if rtt is None:
161+
rtt = self.connect_rtt
162+
max_time_ms = timeout - rtt
163+
if max_time_ms < 0:
164+
timeout_details = _get_timeout_details(self.opts)
165+
formatted = format_timeout_details(timeout_details)
166+
# CSOT: raise an error without running the command since we know it will time out.
167+
errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}"
168+
if self.max_wire_version != -1:
169+
raise ExecutionTimeout(
170+
errmsg,
171+
50,
172+
{"ok": 0, "errmsg": errmsg, "code": 50},
173+
self.max_wire_version,
174+
)
175+
else:
176+
raise TimeoutError(errmsg)
177+
if cmd is not None:
178+
cmd["maxTimeMS"] = int(max_time_ms * 1000)
179+
self.set_conn_timeout(timeout)
180+
return timeout
181+
182+
async def close_conn(self, reason: Optional[str]) -> None:
183+
"""Close this connection with a reason."""
184+
if self.closed:
185+
return
186+
await self._close_conn()
187+
188+
async def _close_conn(self) -> None:
189+
"""Close this connection."""
190+
if self.closed:
191+
return
192+
self.closed = True
193+
self.cancel_context.cancel()
194+
# Note: We catch exceptions to avoid spurious errors on interpreter
195+
# shutdown.
196+
try:
197+
await self.conn.close()
198+
except Exception: # noqa: S110
199+
pass
200+
201+
def conn_closed(self) -> bool:
202+
"""Return True if we know socket has been closed, False otherwise."""
203+
if _IS_SYNC:
204+
return self.socket_checker.socket_closed(self.conn.get_conn)
205+
else:
206+
return self.conn.is_closing()
207+
208+
209+
class AsyncConnection(AsyncBaseConnection):
128210
"""Store a connection with some metadata.
129211
130212
:param conn: a raw connection object
@@ -142,29 +224,27 @@ def __init__(
142224
id: int,
143225
is_sdam: bool,
144226
):
227+
super().__init__(conn, pool.opts)
145228
self.pool_ref = weakref.ref(pool)
146-
self.conn = conn
147-
self.address = address
148-
self.id = id
229+
self.address: tuple[str, int] = address
230+
self.id: int = id
149231
self.is_sdam = is_sdam
150-
self.closed = False
151232
self.last_checkin_time = time.monotonic()
152233
self.performed_handshake = False
153234
self.is_writable: bool = False
154235
self.max_wire_version = MAX_WIRE_VERSION
155-
self.max_bson_size = MAX_BSON_SIZE
156-
self.max_message_size = MAX_MESSAGE_SIZE
157-
self.max_write_batch_size = MAX_WRITE_BATCH_SIZE
236+
self.max_bson_size: int = MAX_BSON_SIZE
237+
self.max_message_size: int = MAX_MESSAGE_SIZE
238+
self.max_write_batch_size: int = MAX_WRITE_BATCH_SIZE
158239
self.supports_sessions = False
159240
self.hello_ok: bool = False
160-
self.is_mongos = False
241+
self.is_mongos: bool = False
161242
self.op_msg_enabled = False
162243
self.listeners = pool.opts._event_listeners
163244
self.enabled_for_cmap = pool.enabled_for_cmap
164245
self.enabled_for_logging = pool.enabled_for_logging
165246
self.compression_settings = pool.opts._compression_settings
166247
self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None
167-
self.socket_checker: SocketChecker = SocketChecker()
168248
self.oidc_token_gen_id: Optional[int] = None
169249
# Support for mechanism negotiation on the initial handshake.
170250
self.negotiated_mechs: Optional[list[str]] = None
@@ -175,9 +255,6 @@ def __init__(
175255
self.pool_gen = pool.gen
176256
self.generation = self.pool_gen.get_overall()
177257
self.ready = False
178-
self.cancel_context: _CancellationContext = _CancellationContext()
179-
self.opts = pool.opts
180-
self.more_to_come: bool = False
181258
# For load balancer support.
182259
self.service_id: Optional[ObjectId] = None
183260
self.server_connection_id: Optional[int] = None
@@ -193,44 +270,6 @@ def __init__(
193270
# For gossiping $clusterTime from the connection handshake to the client.
194271
self._cluster_time = None
195272

196-
def set_conn_timeout(self, timeout: Optional[float]) -> None:
197-
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
198-
if timeout == self.last_timeout:
199-
return
200-
self.last_timeout = timeout
201-
self.conn.get_conn.settimeout(timeout)
202-
203-
def apply_timeout(
204-
self, client: AsyncMongoClient, cmd: Optional[MutableMapping[str, Any]]
205-
) -> Optional[float]:
206-
# CSOT: use remaining timeout when set.
207-
timeout = _csot.remaining()
208-
if timeout is None:
209-
# Reset the socket timeout unless we're performing a streaming monitor check.
210-
if not self.more_to_come:
211-
self.set_conn_timeout(self.opts.socket_timeout)
212-
return None
213-
# RTT validation.
214-
rtt = _csot.get_rtt()
215-
if rtt is None:
216-
rtt = self.connect_rtt
217-
max_time_ms = timeout - rtt
218-
if max_time_ms < 0:
219-
timeout_details = _get_timeout_details(self.opts)
220-
formatted = format_timeout_details(timeout_details)
221-
# CSOT: raise an error without running the command since we know it will time out.
222-
errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}"
223-
raise ExecutionTimeout(
224-
errmsg,
225-
50,
226-
{"ok": 0, "errmsg": errmsg, "code": 50},
227-
self.max_wire_version,
228-
)
229-
if cmd is not None:
230-
cmd["maxTimeMS"] = int(max_time_ms * 1000)
231-
self.set_conn_timeout(timeout)
232-
return timeout
233-
234273
def pin_txn(self) -> None:
235274
self.pinned_txn = True
236275
assert not self.pinned_cursor
@@ -555,9 +594,7 @@ def validate_session(
555594

556595
async def close_conn(self, reason: Optional[str]) -> None:
557596
"""Close this connection with a reason."""
558-
if self.closed:
559-
return
560-
await self._close_conn()
597+
await super().close_conn(reason)
561598
if reason:
562599
if self.enabled_for_cmap:
563600
assert self.listeners is not None
@@ -574,26 +611,6 @@ async def close_conn(self, reason: Optional[str]) -> None:
574611
error=reason,
575612
)
576613

577-
async def _close_conn(self) -> None:
578-
"""Close this connection."""
579-
if self.closed:
580-
return
581-
self.closed = True
582-
self.cancel_context.cancel()
583-
# Note: We catch exceptions to avoid spurious errors on interpreter
584-
# shutdown.
585-
try:
586-
await self.conn.close()
587-
except Exception: # noqa: S110
588-
pass
589-
590-
def conn_closed(self) -> bool:
591-
"""Return True if we know socket has been closed, False otherwise."""
592-
if _IS_SYNC:
593-
return self.socket_checker.socket_closed(self.conn.get_conn)
594-
else:
595-
return self.conn.is_closing()
596-
597614
def send_cluster_time(
598615
self,
599616
command: MutableMapping[str, Any],

0 commit comments

Comments
 (0)