Skip to content

Commit 0a43477

Browse files
committed
wip
1 parent 8494954 commit 0a43477

File tree

4 files changed

+33
-24
lines changed

4 files changed

+33
-24
lines changed

pymongo/asynchronous/encryption.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
NetworkTimeout,
7777
ServerSelectionTimeoutError,
7878
)
79-
from pymongo.network_layer import async_receive_kms, async_sendall
79+
from pymongo.network_layer import PyMongoKMSProtocol, async_receive_kms, async_sendall
8080
from pymongo.operations import UpdateOne
8181
from pymongo.pool_options import PoolOptions
8282
from pymongo.pool_shared import (
@@ -187,16 +187,17 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
187187
sleep_sec = float(sleep_u) / 1e6
188188
await asyncio.sleep(sleep_sec)
189189
try:
190-
interface = await _configured_protocol_interface(address, opts)
190+
interface = await _configured_protocol_interface(address, opts, PyMongoKMSProtocol)
191191
conn = AsyncBaseConnection(interface, opts)
192192
try:
193193
await async_sendall(interface.get_conn, message)
194-
# CSOT: update timeout.
195-
interface.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
196-
data = await async_receive_kms(conn, kms_context.bytes_needed)
197-
if not data:
198-
raise OSError("KMS connection closed")
199-
kms_context.feed(bytes)
194+
while kms_context.bytes_needed > 0:
195+
# CSOT: update timeout.
196+
interface.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
197+
data = await async_receive_kms(conn, kms_context.bytes_needed)
198+
if not data:
199+
raise OSError("KMS connection closed")
200+
kms_context.feed(data)
200201
except MongoCryptError:
201202
raise # Propagate MongoCryptError errors directly.
202203
except Exception as exc:
@@ -212,7 +213,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
212213
address, exc, msg_prefix=msg_prefix, timeout_details=_get_timeout_details(opts)
213214
)
214215
finally:
215-
conn.close_conn()
216+
conn.close_conn(None)
216217
except MongoCryptError:
217218
raise # Propagate MongoCryptError errors directly.
218219
except Exception as exc:

pymongo/network_layer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ def __init__(self, timeout: Optional[float] = None):
257257
self.transport: Transport = None # type: ignore[assignment]
258258
self._timeout = timeout
259259
self._closed = asyncio.get_running_loop().create_future()
260+
self._connection_lost = False
260261

261262
def settimeout(self, timeout: float | None) -> None:
262263
self._timeout = timeout
@@ -309,7 +310,6 @@ def __init__(self, timeout: Optional[float] = None):
309310
self._expecting_compression = False
310311
self._message_size = 0
311312
self._op_code = 0
312-
self._connection_lost = False
313313
self._read_waiter: Optional[Future[Any]] = None
314314
self._is_compressed = False
315315
self._compressor_id: Optional[int] = None
@@ -551,7 +551,8 @@ def buffer_updated(self, nbytes: int) -> None:
551551
buffer = self._buffers.popleft()
552552
buffer_remaining = len(buffer.buffer) - buffer.start_index
553553
# if we didn't exhaust the buffer, read the partial data and put it back.
554-
if buffer_remaining <= n_remaining:
554+
if buffer_remaining > n_remaining:
555+
n_remaining = 0
555556
output_buf[out_index : n_remaining + out_index] = buffer.buffer[
556557
buffer.start_index : buffer.start_index + n_remaining
557558
]
@@ -562,6 +563,7 @@ def buffer_updated(self, nbytes: int) -> None:
562563
output_buf[out_index : out_index + buffer_remaining] = buffer.buffer[
563564
buffer.start_index :
564565
]
566+
n_remaining -= buffer_remaining
565567
buffer.start_index = 0
566568
self._buffer_pool.append(buffer)
567569
waiter.set_result(output_buf)

pymongo/pool_shared.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,12 @@
3636
NetworkTimeout,
3737
_CertificateError,
3838
)
39-
from pymongo.network_layer import AsyncNetworkingInterface, NetworkingInterface, PyMongoProtocol
39+
from pymongo.network_layer import (
40+
AsyncNetworkingInterface,
41+
NetworkingInterface,
42+
PyMongoBaseProtocol,
43+
PyMongoProtocol,
44+
)
4045
from pymongo.pool_options import PoolOptions
4146
from pymongo.ssl_support import PYSSLError, SSLError, _has_sni
4247

@@ -326,7 +331,7 @@ async def _async_configured_socket(
326331

327332

328333
async def _configured_protocol_interface(
329-
address: _Address, options: PoolOptions
334+
address: _Address, options: PoolOptions, protocol_kls: PyMongoBaseProtocol = PyMongoProtocol
330335
) -> AsyncNetworkingInterface:
331336
"""Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface.
332337
@@ -341,7 +346,7 @@ async def _configured_protocol_interface(
341346
if ssl_context is None:
342347
return AsyncNetworkingInterface(
343348
await asyncio.get_running_loop().create_connection(
344-
lambda: PyMongoProtocol(timeout=timeout), sock=sock
349+
lambda: protocol_kls(timeout=timeout), sock=sock
345350
)
346351
)
347352

@@ -350,7 +355,7 @@ async def _configured_protocol_interface(
350355
# We have to pass hostname / ip address to wrap_socket
351356
# to use SSLContext.check_hostname.
352357
transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
353-
lambda: PyMongoProtocol(timeout=timeout),
358+
lambda: protocol_kls(timeout=timeout),
354359
sock=sock,
355360
server_hostname=host,
356361
ssl=ssl_context,

pymongo/synchronous/encryption.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070
NetworkTimeout,
7171
ServerSelectionTimeoutError,
7272
)
73-
from pymongo.network_layer import receive_kms, sendall
73+
from pymongo.network_layer import PyMongoKMSProtocol, receive_kms, sendall
7474
from pymongo.operations import UpdateOne
7575
from pymongo.pool_options import PoolOptions
7676
from pymongo.pool_shared import (
@@ -186,16 +186,17 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
186186
sleep_sec = float(sleep_u) / 1e6
187187
time.sleep(sleep_sec)
188188
try:
189-
interface = _configured_socket_interface(address, opts)
189+
interface = _configured_socket_interface(address, opts, PyMongoKMSProtocol)
190190
conn = BaseConnection(interface, opts)
191191
try:
192192
sendall(interface.get_conn, message)
193-
# CSOT: update timeout.
194-
interface.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
195-
data = receive_kms(conn, kms_context.bytes_needed)
196-
if not data:
197-
raise OSError("KMS connection closed")
198-
kms_context.feed(bytes)
193+
while kms_context.bytes_needed > 0:
194+
# CSOT: update timeout.
195+
interface.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
196+
data = receive_kms(conn, kms_context.bytes_needed)
197+
if not data:
198+
raise OSError("KMS connection closed")
199+
kms_context.feed(data)
199200
except MongoCryptError:
200201
raise # Propagate MongoCryptError errors directly.
201202
except Exception as exc:
@@ -211,7 +212,7 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
211212
address, exc, msg_prefix=msg_prefix, timeout_details=_get_timeout_details(opts)
212213
)
213214
finally:
214-
conn.close_conn()
215+
conn.close_conn(None)
215216
except MongoCryptError:
216217
raise # Propagate MongoCryptError errors directly.
217218
except Exception as exc:

0 commit comments

Comments
 (0)