Skip to content

Commit b472d3a

Browse files
committed
Async KMS
1 parent 17388d1 commit b472d3a

File tree

3 files changed

+120
-28
lines changed

3 files changed

+120
-28
lines changed

pymongo/asynchronous/encryption.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,32 @@ def __init__(
142142
self.opts = opts
143143
self._spawned = False
144144

145+
async def _async_kms_request(
146+
self, kms_context: MongoCryptKmsContext, host, port, opts, message
147+
) -> None:
148+
from pymongo.network_layer import _async_receive_data_socket
149+
150+
try:
151+
conn = await _configured_socket((host, port), opts)
152+
try:
153+
await async_sendall(conn, message)
154+
while kms_context.bytes_needed > 0:
155+
# CSOT: update timeout.
156+
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
157+
data = await _async_receive_data_socket(conn, kms_context.bytes_needed)
158+
if not data:
159+
raise OSError("KMS connection closed")
160+
kms_context.feed(data)
161+
except BLOCKING_IO_ERRORS:
162+
raise socket.timeout("timed out") from None
163+
finally:
164+
conn.close()
165+
except (PyMongoError, MongoCryptError):
166+
raise # Propagate pymongo errors directly.
167+
except Exception as error:
168+
# Wrap I/O errors in PyMongo exceptions.
169+
_raise_connection_failure((host, port), error)
170+
145171
async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
146172
"""Complete a KMS request.
147173
@@ -174,20 +200,23 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
174200
)
175201
host, port = parse_host(endpoint, _HTTPS_PORT)
176202
try:
177-
conn = await _configured_socket((host, port), opts)
178-
try:
179-
await async_sendall(conn, message)
180-
while kms_context.bytes_needed > 0:
181-
# CSOT: update timeout.
182-
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
183-
data = conn.recv(kms_context.bytes_needed)
184-
if not data:
185-
raise OSError("KMS connection closed")
186-
kms_context.feed(data)
187-
except BLOCKING_IO_ERRORS:
188-
raise socket.timeout("timed out") from None
189-
finally:
190-
conn.close()
203+
if _IS_SYNC:
204+
conn = await _configured_socket((host, port), opts)
205+
try:
206+
await async_sendall(conn, message)
207+
while kms_context.bytes_needed > 0:
208+
# CSOT: update timeout.
209+
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
210+
data = conn.recv(kms_context.bytes_needed)
211+
if not data:
212+
raise OSError("KMS connection closed")
213+
kms_context.feed(data)
214+
except BLOCKING_IO_ERRORS:
215+
raise socket.timeout("timed out") from None
216+
finally:
217+
conn.close()
218+
else:
219+
return await self._async_kms_request(kms_context, host, port, opts, message)
191220
except (PyMongoError, MongoCryptError):
192221
raise # Propagate pymongo errors directly.
193222
except Exception as error:

pymongo/network_layer.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,40 @@ async def async_receive_data(
275275
sock.settimeout(sock_timeout)
276276

277277

278+
async def _async_receive_data_socket(
279+
sock: socket.socket | _sslConn, length: int, deadline: Optional[float]
280+
) -> memoryview:
281+
sock_timeout = sock.gettimeout()
282+
timeout: Optional[Union[float, int]]
283+
if deadline:
284+
# When the timeout has expired perform one final check to
285+
# see if the socket is readable. This helps avoid spurious
286+
# timeouts on AWS Lambda and other FaaS environments.
287+
timeout = max(deadline - time.monotonic(), 0)
288+
else:
289+
timeout = sock_timeout
290+
291+
sock.settimeout(0.0)
292+
loop = asyncio.get_event_loop()
293+
try:
294+
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
295+
read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
296+
else:
297+
read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
298+
tasks = [read_task]
299+
done, pending = await asyncio.wait(
300+
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
301+
)
302+
for task in pending:
303+
task.cancel()
304+
await asyncio.wait(pending)
305+
if read_task in done:
306+
return read_task.result()
307+
raise socket.timeout("timed out")
308+
finally:
309+
sock.settimeout(sock_timeout)
310+
311+
278312
async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview:
279313
mv = memoryview(bytearray(length))
280314
bytes_read = 0

pymongo/synchronous/encryption.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,32 @@ def __init__(
142142
self.opts = opts
143143
self._spawned = False
144144

145+
def _async_kms_request(
146+
self, kms_context: MongoCryptKmsContext, host, port, opts, message
147+
) -> None:
148+
from pymongo.network_layer import _receive_data_socket
149+
150+
try:
151+
conn = _configured_socket((host, port), opts)
152+
try:
153+
sendall(conn, message)
154+
while kms_context.bytes_needed > 0:
155+
# CSOT: update timeout.
156+
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
157+
data = _receive_data_socket(conn, kms_context.bytes_needed)
158+
if not data:
159+
raise OSError("KMS connection closed")
160+
kms_context.feed(data)
161+
except BLOCKING_IO_ERRORS:
162+
raise socket.timeout("timed out") from None
163+
finally:
164+
conn.close()
165+
except (PyMongoError, MongoCryptError):
166+
raise # Propagate pymongo errors directly.
167+
except Exception as error:
168+
# Wrap I/O errors in PyMongo exceptions.
169+
_raise_connection_failure((host, port), error)
170+
145171
def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
146172
"""Complete a KMS request.
147173
@@ -174,20 +200,23 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
174200
)
175201
host, port = parse_host(endpoint, _HTTPS_PORT)
176202
try:
177-
conn = _configured_socket((host, port), opts)
178-
try:
179-
sendall(conn, message)
180-
while kms_context.bytes_needed > 0:
181-
# CSOT: update timeout.
182-
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
183-
data = conn.recv(kms_context.bytes_needed)
184-
if not data:
185-
raise OSError("KMS connection closed")
186-
kms_context.feed(data)
187-
except BLOCKING_IO_ERRORS:
188-
raise socket.timeout("timed out") from None
189-
finally:
190-
conn.close()
203+
if _IS_SYNC:
204+
conn = _configured_socket((host, port), opts)
205+
try:
206+
sendall(conn, message)
207+
while kms_context.bytes_needed > 0:
208+
# CSOT: update timeout.
209+
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
210+
data = conn.recv(kms_context.bytes_needed)
211+
if not data:
212+
raise OSError("KMS connection closed")
213+
kms_context.feed(data)
214+
except BLOCKING_IO_ERRORS:
215+
raise socket.timeout("timed out") from None
216+
finally:
217+
conn.close()
218+
else:
219+
return self._async_kms_request(kms_context, host, port, opts, message)
191220
except (PyMongoError, MongoCryptError):
192221
raise # Propagate pymongo errors directly.
193222
except Exception as error:

0 commit comments

Comments
 (0)