Skip to content

Commit 8e6a30d

Browse files
committed
Fix async KMS
1 parent 3bdd381 commit 8e6a30d

File tree

3 files changed

+19
-32
lines changed

3 files changed

+19
-32
lines changed

pymongo/asynchronous/encryption.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,10 @@ async def _async_kms_request(
154154
while kms_context.bytes_needed > 0:
155155
# CSOT: update timeout.
156156
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
157-
data = await _async_receive_data_socket(conn, kms_context.bytes_needed, None)
158-
if not data:
159-
raise OSError("KMS connection closed")
157+
data = await _async_receive_data_socket(conn, kms_context.bytes_needed)
160158
kms_context.feed(data)
159+
except OSError as err:
160+
raise OSError("KMS connection closed") from err
161161
except BLOCKING_IO_ERRORS:
162162
raise socket.timeout("timed out") from None
163163
finally:
@@ -216,7 +216,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
216216
finally:
217217
conn.close()
218218
else:
219-
return await self._async_kms_request(kms_context, host, port, opts, message)
219+
await self._async_kms_request(kms_context, host, port, opts, message)
220220
except (PyMongoError, MongoCryptError):
221221
raise # Propagate pymongo errors directly.
222222
except Exception as error:

pymongo/network_layer.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def _is_ready(fut: Future) -> None:
130130
loop.remove_writer(fd)
131131

132132
async def _async_receive_ssl(
133-
conn: _sslConn, length: int, loop: AbstractEventLoop
133+
conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False
134134
) -> memoryview:
135135
mv = memoryview(bytearray(length))
136136
total_read = 0
@@ -145,6 +145,8 @@ def _is_ready(fut: Future) -> None:
145145
read = conn.recv_into(mv[total_read:])
146146
if read == 0:
147147
raise OSError("connection closed")
148+
if once:
149+
return mv[:read]
148150
total_read += read
149151
except BLOCKING_IO_ERRORS as exc:
150152
fd = conn.fileno()
@@ -275,36 +277,21 @@ async def async_receive_data(
275277
sock.settimeout(sock_timeout)
276278

277279

278-
async def _async_receive_data_socket(
279-
sock: socket.socket | _sslConn, length: int, deadline: Optional[float]
280-
) -> memoryview:
280+
async def _async_receive_data_socket(sock: socket.socket | _sslConn, length: int) -> memoryview:
281281
sock_timeout = sock.gettimeout()
282-
timeout: Optional[Union[float, int]]
283-
if deadline:
284-
# When the timeout has expired perform one final check to
285-
# see if the socket is readable. This helps avoid spurious
286-
# timeouts on AWS Lambda and other FaaS environments.
287-
timeout = max(deadline - time.monotonic(), 0)
288-
else:
289-
timeout = sock_timeout
282+
timeout = sock_timeout
290283

291284
sock.settimeout(0.0)
292285
loop = asyncio.get_event_loop()
293286
try:
294287
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
295-
read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
288+
return await asyncio.wait_for(
289+
_async_receive_ssl(sock, length, loop, once=True), timeout=timeout
290+
) # type: ignore[arg-type]
296291
else:
297-
read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
298-
tasks = [read_task]
299-
done, pending = await asyncio.wait(
300-
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
301-
)
302-
for task in pending:
303-
task.cancel()
304-
await asyncio.wait(pending)
305-
if read_task in done:
306-
return read_task.result()
307-
raise socket.timeout("timed out")
292+
return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type]
293+
except asyncio.TimeoutError as err:
294+
raise socket.timeout("timed out") from err
308295
finally:
309296
sock.settimeout(sock_timeout)
310297

pymongo/synchronous/encryption.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,10 @@ def _async_kms_request(
154154
while kms_context.bytes_needed > 0:
155155
# CSOT: update timeout.
156156
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
157-
data = _receive_data_socket(conn, kms_context.bytes_needed, None)
158-
if not data:
159-
raise OSError("KMS connection closed")
157+
data = _receive_data_socket(conn, kms_context.bytes_needed)
160158
kms_context.feed(data)
159+
except OSError as err:
160+
raise OSError("KMS connection closed") from err
161161
except BLOCKING_IO_ERRORS:
162162
raise socket.timeout("timed out") from None
163163
finally:
@@ -216,7 +216,7 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
216216
finally:
217217
conn.close()
218218
else:
219-
return self._async_kms_request(kms_context, host, port, opts, message)
219+
self._async_kms_request(kms_context, host, port, opts, message)
220220
except (PyMongoError, MongoCryptError):
221221
raise # Propagate pymongo errors directly.
222222
except Exception as error:

0 commit comments

Comments
 (0)