Skip to content

Commit e50685c

Browse files
committed
fix sync kms
1 parent 484aa9f commit e50685c

File tree

4 files changed

+96
-62
lines changed

4 files changed

+96
-62
lines changed

pymongo/asynchronous/encryption.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@
9494
if TYPE_CHECKING:
9595
from pymongocrypt.mongocrypt import MongoCryptKmsContext
9696

97+
from pymongo.typings import _Address
9798

9899
_IS_SYNC = False
99100

@@ -109,6 +110,14 @@
109110
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
110111

111112

113+
async def _connect_kms(address: _Address, opts: PoolOptions) -> AsyncBaseConnection:
114+
try:
115+
interface = await _configured_protocol_interface(address, opts, PyMongoKMSProtocol)
116+
return AsyncBaseConnection(interface, opts)
117+
except Exception as exc:
118+
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
119+
120+
112121
@contextlib.contextmanager
113122
def _wrap_encryption_errors() -> Iterator[None]:
114123
"""Context manager to wrap encryption related errors."""
@@ -187,13 +196,17 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
187196
sleep_sec = float(sleep_u) / 1e6
188197
await asyncio.sleep(sleep_sec)
189198
try:
190-
interface = await _configured_protocol_interface(address, opts, PyMongoKMSProtocol)
191-
conn = AsyncBaseConnection(interface, opts)
199+
conn = await _connect_kms(address, opts)
192200
try:
193-
await async_sendall(interface.get_conn, message)
201+
await async_sendall(conn.conn.get_conn, message)
194202
while kms_context.bytes_needed > 0:
195203
# CSOT: update timeout.
196-
interface.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
204+
conn.set_conn_timeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
205+
# if _IS_SYNC:
206+
# # TODO: why can't we use receive_kms?
207+
# data = conn.conn.sock.recv(kms_context.bytes_needed)
208+
# else:
209+
# data = await async_receive_kms(conn, kms_context.bytes_needed)
197210
data = await async_receive_kms(conn, kms_context.bytes_needed)
198211
if not data:
199212
raise OSError("KMS connection closed")

pymongo/network_layer.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@
4141
from pymongo.socket_checker import _errno_from_exception
4242

4343
if TYPE_CHECKING:
44-
from pymongo.asynchronous.pool import AsyncBaseConnection
44+
from pymongo.asynchronous.pool import AsyncBaseConnection, AsyncConnection
4545
from pymongo.pyopenssl_context import _sslConn
46-
from pymongo.synchronous.pool import BaseConnection
46+
from pymongo.synchronous.pool import BaseConnection, Connection
4747

4848
_UNPACK_HEADER = struct.Struct("<iiii").unpack
4949
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
@@ -642,7 +642,7 @@ async def _async_receive_data(
642642

643643

644644
async def async_receive_message(
645-
conn: AsyncBaseConnection,
645+
conn: AsyncConnection,
646646
request_id: Optional[int],
647647
max_message_size: int = MAX_MESSAGE_SIZE,
648648
) -> Union[_OpReply, _OpMsg]:
@@ -663,19 +663,11 @@ def callback(result: Any) -> _OpMsg | _OpReply:
663663

664664
def receive_kms(conn: BaseConnection, bytes_needed: int) -> bytes:
665665
"""Receive raw bytes from the kms connection."""
666-
if _csot.get_timeout():
667-
deadline = _csot.get_deadline()
668-
else:
669-
timeout = conn.conn.gettimeout()
670-
if timeout:
671-
deadline = time.monotonic() + timeout
672-
else:
673-
deadline = None
674-
return receive_data(conn, bytes_needed, deadline)
666+
return conn.conn.sock.recv(bytes_needed)
675667

676668

677669
def receive_message(
678-
conn: BaseConnection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
670+
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
679671
) -> Union[_OpReply, _OpMsg]:
680672
"""Receive a raw BSON message or raise socket.error."""
681673
if _csot.get_timeout():

pymongo/synchronous/encryption.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
if TYPE_CHECKING:
9494
from pymongocrypt.mongocrypt import MongoCryptKmsContext
9595

96+
from pymongo.typings import _Address
9697

9798
_IS_SYNC = True
9899

@@ -108,6 +109,14 @@
108109
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
109110

110111

112+
def _connect_kms(address: _Address, opts: PoolOptions) -> BaseConnection:
113+
try:
114+
interface = _configured_socket_interface(address, opts, PyMongoKMSProtocol)
115+
return BaseConnection(interface, opts)
116+
except Exception as exc:
117+
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
118+
119+
111120
@contextlib.contextmanager
112121
def _wrap_encryption_errors() -> Iterator[None]:
113122
"""Context manager to wrap encryption related errors."""
@@ -186,13 +195,17 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
186195
sleep_sec = float(sleep_u) / 1e6
187196
time.sleep(sleep_sec)
188197
try:
189-
interface = _configured_socket_interface(address, opts, PyMongoKMSProtocol)
190-
conn = BaseConnection(interface, opts)
198+
conn = _connect_kms(address, opts)
191199
try:
192-
sendall(interface.get_conn, message)
200+
sendall(conn.conn.get_conn, message)
193201
while kms_context.bytes_needed > 0:
194202
# CSOT: update timeout.
195-
interface.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
203+
conn.set_conn_timeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
204+
# if _IS_SYNC:
205+
# # TODO: why can't we use receive_kms?
206+
# data = conn.conn.sock.recv(kms_context.bytes_needed)
207+
# else:
208+
# data = receive_kms(conn, kms_context.bytes_needed)
196209
data = receive_kms(conn, kms_context.bytes_needed)
197210
if not data:
198211
raise OSError("KMS connection closed")

0 commit comments

Comments
 (0)