Skip to content

Commit 4b25e95

Browse files
committed
Fix synchro
1 parent 2d0f4c1 commit 4b25e95

File tree

4 files changed

+118
-624
lines changed

4 files changed

+118
-624
lines changed

pymongo/synchronous/auth.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,20 @@ def _auth_key(nonce: str, username: str, password: str) -> str:
174174
return md5hash.hexdigest()
175175

176176

177-
def _canonicalize_hostname(hostname: str) -> str:
177+
def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
178178
"""Canonicalize hostname following MIT-krb5 behavior."""
179179
# https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
180+
if option in [False, "none"]:
181+
return hostname
182+
180183
af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
181184
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
182185
)[0]
183186

187+
# For forward just to resolve the cname as dns.lookup() will not return it.
188+
if option == "forward":
189+
return canonname.lower()
190+
184191
try:
185192
name = socket.getnameinfo(sockaddr, socket.NI_NAMEREQD)
186193
except socket.gaierror:
@@ -202,9 +209,8 @@ def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None
202209
props = credentials.mechanism_properties
203210
# Starting here and continuing through the while loop below - establish
204211
# the security context. See RFC 4752, Section 3.1, first paragraph.
205-
host = conn.address[0]
206-
if props.canonicalize_host_name:
207-
host = _canonicalize_hostname(host)
212+
host = props.service_host or conn.address[0]
213+
host = _canonicalize_hostname(host, props.canonicalize_host_name)
208214
service = props.service_name + "@" + host
209215
if props.service_realm is not None:
210216
service = service + "@" + props.service_realm

pymongo/synchronous/encryption.py

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import contextlib
2020
import enum
2121
import socket
22+
import time as time # noqa: PLC0414 # needed in sync version
2223
import uuid
2324
import weakref
2425
from copy import deepcopy
@@ -67,7 +68,7 @@
6768
EncryptedCollectionError,
6869
EncryptionError,
6970
InvalidOperation,
70-
PyMongoError,
71+
NetworkTimeout,
7172
ServerSelectionTimeoutError,
7273
)
7374
from pymongo.network_layer import BLOCKING_IO_ERRORS, sendall
@@ -80,14 +81,21 @@
8081
from pymongo.synchronous.cursor import Cursor
8182
from pymongo.synchronous.database import Database
8283
from pymongo.synchronous.mongo_client import MongoClient
83-
from pymongo.synchronous.pool import _configured_socket, _raise_connection_failure
84+
from pymongo.synchronous.pool import (
85+
_configured_socket,
86+
_get_timeout_details,
87+
_raise_connection_failure,
88+
)
8489
from pymongo.typings import _DocumentType, _DocumentTypeArg
8590
from pymongo.uri_parser import parse_host
8691
from pymongo.write_concern import WriteConcern
8792

8893
if TYPE_CHECKING:
8994
from pymongocrypt.mongocrypt import MongoCryptKmsContext
9095

96+
from pymongo.pyopenssl_context import _sslConn
97+
from pymongo.typings import _Address
98+
9199

92100
_IS_SYNC = True
93101

@@ -103,6 +111,13 @@
103111
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
104112

105113

114+
def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
115+
try:
116+
return _configured_socket(address, opts)
117+
except Exception as exc:
118+
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
119+
120+
106121
@contextlib.contextmanager
107122
def _wrap_encryption_errors() -> Iterator[None]:
108123
"""Context manager to wrap encryption related errors."""
@@ -166,18 +181,22 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
166181
None, # crlfile
167182
False, # allow_invalid_certificates
168183
False, # allow_invalid_hostnames
169-
False,
170-
) # disable_ocsp_endpoint_check
184+
False, # disable_ocsp_endpoint_check
185+
)
171186
# CSOT: set timeout for socket creation.
172187
connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001)
173188
opts = PoolOptions(
174189
connect_timeout=connect_timeout,
175190
socket_timeout=connect_timeout,
176191
ssl_context=ctx,
177192
)
178-
host, port = parse_host(endpoint, _HTTPS_PORT)
193+
address = parse_host(endpoint, _HTTPS_PORT)
194+
sleep_u = kms_context.usleep
195+
if sleep_u:
196+
sleep_sec = float(sleep_u) / 1e6
197+
time.sleep(sleep_sec)
179198
try:
180-
conn = _configured_socket((host, port), opts)
199+
conn = _connect_kms(address, opts)
181200
try:
182201
sendall(conn, message)
183202
while kms_context.bytes_needed > 0:
@@ -194,20 +213,29 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
194213
if not data:
195214
raise OSError("KMS connection closed")
196215
kms_context.feed(data)
197-
# Async raises an OSError instead of returning empty bytes
198-
except OSError as err:
199-
raise OSError("KMS connection closed") from err
200-
except BLOCKING_IO_ERRORS:
201-
raise socket.timeout("timed out") from None
216+
except MongoCryptError:
217+
raise # Propagate MongoCryptError errors directly.
218+
except Exception as exc:
219+
# Wrap I/O errors in PyMongo exceptions.
220+
if isinstance(exc, BLOCKING_IO_ERRORS):
221+
exc = socket.timeout("timed out")
222+
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
202223
finally:
203224
conn.close()
204-
except (PyMongoError, MongoCryptError):
205-
raise # Propagate pymongo errors directly.
206-
except asyncio.CancelledError:
207-
raise
208-
except Exception as error:
209-
# Wrap I/O errors in PyMongo exceptions.
210-
_raise_connection_failure((host, port), error)
225+
except MongoCryptError:
226+
raise # Propagate MongoCryptError errors directly.
227+
except Exception as exc:
228+
remaining = _csot.remaining()
229+
if isinstance(exc, NetworkTimeout) or (remaining is not None and remaining <= 0):
230+
raise
231+
# Mark this attempt as failed and defer to libmongocrypt to retry.
232+
try:
233+
kms_context.fail()
234+
except MongoCryptError as final_err:
235+
exc = MongoCryptError(
236+
f"{final_err}, last attempt failed with: {exc}", final_err.code
237+
)
238+
raise exc from final_err
211239

212240
def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
213241
"""Get the collection info for a namespace.

pymongo/synchronous/network.py

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import datetime
1919
import logging
20+
import time
2021
from typing import (
2122
TYPE_CHECKING,
2223
Any,
@@ -30,16 +31,20 @@
3031

3132
from bson import _decode_all_selective
3233
from pymongo import _csot, helpers_shared, message
33-
from pymongo.compression_support import _NO_COMPRESSION
34+
from pymongo.common import MAX_MESSAGE_SIZE
35+
from pymongo.compression_support import _NO_COMPRESSION, decompress
3436
from pymongo.errors import (
3537
NotPrimaryError,
3638
OperationFailure,
39+
ProtocolError,
3740
)
3841
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
39-
from pymongo.message import _OpMsg
42+
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
4043
from pymongo.monitoring import _is_speculative_authenticate
4144
from pymongo.network_layer import (
42-
receive_message,
45+
_UNPACK_COMPRESSION_HEADER,
46+
_UNPACK_HEADER,
47+
receive_data,
4348
sendall,
4449
)
4550

@@ -51,15 +56,15 @@
5156
from pymongo.read_preferences import _ServerMode
5257
from pymongo.synchronous.client_session import ClientSession
5358
from pymongo.synchronous.mongo_client import MongoClient
54-
from pymongo.synchronous.pool import ConnectionProtocol
59+
from pymongo.synchronous.pool import Connection
5560
from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
5661
from pymongo.write_concern import WriteConcern
5762

5863
_IS_SYNC = True
5964

6065

6166
def command(
62-
conn: ConnectionProtocol,
67+
conn: Connection,
6368
dbname: str,
6469
spec: MutableMapping[str, Any],
6570
is_mongos: bool,
@@ -189,7 +194,7 @@ def command(
189194
)
190195

191196
try:
192-
sendall(conn, msg)
197+
sendall(conn.conn, msg)
193198
if use_op_msg and unacknowledged:
194199
# Unacknowledged, fake a successful command response.
195200
reply = None
@@ -292,3 +297,45 @@ def command(
292297
)
293298

294299
return response_doc # type: ignore[return-value]
300+
301+
302+
def receive_message(
303+
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
304+
) -> Union[_OpReply, _OpMsg]:
305+
"""Receive a raw BSON message or raise socket.error."""
306+
if _csot.get_timeout():
307+
deadline = _csot.get_deadline()
308+
else:
309+
timeout = conn.conn.gettimeout()
310+
if timeout:
311+
deadline = time.monotonic() + timeout
312+
else:
313+
deadline = None
314+
# Ignore the response's request id.
315+
length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline))
316+
# No request_id for exhaust cursor "getMore".
317+
if request_id is not None:
318+
if request_id != response_to:
319+
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
320+
if length <= 16:
321+
raise ProtocolError(
322+
f"Message length ({length!r}) not longer than standard message header size (16)"
323+
)
324+
if length > max_message_size:
325+
raise ProtocolError(
326+
f"Message length ({length!r}) is larger than server max "
327+
f"message size ({max_message_size!r})"
328+
)
329+
if op_code == 2012:
330+
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
331+
data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
332+
else:
333+
data = receive_data(conn, length - 16, deadline)
334+
335+
try:
336+
unpack_reply = _UNPACK_REPLY[op_code]
337+
except KeyError:
338+
raise ProtocolError(
339+
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
340+
) from None
341+
return unpack_reply(data)

0 commit comments

Comments
 (0)