Skip to content

Commit f83a882

Browse files
committed
Merge branch 'master' of github.com:mongodb/mongo-python-driver into add-justfile-jan
2 parents 29fa6b5 + e4d8449 commit f83a882

27 files changed

+229
-236
lines changed

bson/_cbsonmodule.c

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1644,6 +1644,56 @@ static int write_raw_doc(buffer_t buffer, PyObject* raw, PyObject* _raw_str) {
16441644
return bytes_written;
16451645
}
16461646

1647+
1648+
/* Update Invalid Document error message to include doc.
1649+
*/
1650+
void handle_invalid_doc_error(PyObject* dict) {
1651+
PyObject *etype = NULL, *evalue = NULL, *etrace = NULL;
1652+
PyObject *msg = NULL, *dict_str = NULL, *new_msg = NULL;
1653+
PyErr_Fetch(&etype, &evalue, &etrace);
1654+
PyObject *InvalidDocument = _error("InvalidDocument");
1655+
if (InvalidDocument == NULL) {
1656+
goto cleanup;
1657+
}
1658+
1659+
if (evalue && PyErr_GivenExceptionMatches(etype, InvalidDocument)) {
1660+
PyObject *msg = PyObject_Str(evalue);
1661+
if (msg) {
1662+
// Prepend doc to the existing message
1663+
PyObject *dict_str = PyObject_Str(dict);
1664+
if (dict_str == NULL) {
1665+
goto cleanup;
1666+
}
1667+
const char * dict_str_utf8 = PyUnicode_AsUTF8(dict_str);
1668+
if (dict_str_utf8 == NULL) {
1669+
goto cleanup;
1670+
}
1671+
const char * msg_utf8 = PyUnicode_AsUTF8(msg);
1672+
if (msg_utf8 == NULL) {
1673+
goto cleanup;
1674+
}
1675+
PyObject *new_msg = PyUnicode_FromFormat("Invalid document %s | %s", dict_str_utf8, msg_utf8);
1676+
Py_DECREF(evalue);
1677+
Py_DECREF(etype);
1678+
etype = InvalidDocument;
1679+
InvalidDocument = NULL;
1680+
if (new_msg) {
1681+
evalue = new_msg;
1682+
} else {
1683+
evalue = msg;
1684+
}
1685+
}
1686+
PyErr_NormalizeException(&etype, &evalue, &etrace);
1687+
}
1688+
cleanup:
1689+
PyErr_Restore(etype, evalue, etrace);
1690+
Py_XDECREF(msg);
1691+
Py_XDECREF(InvalidDocument);
1692+
Py_XDECREF(dict_str);
1693+
Py_XDECREF(new_msg);
1694+
}
1695+
1696+
16471697
/* returns the number of bytes written or 0 on failure */
16481698
int write_dict(PyObject* self, buffer_t buffer,
16491699
PyObject* dict, unsigned char check_keys,
@@ -1743,40 +1793,8 @@ int write_dict(PyObject* self, buffer_t buffer,
17431793
while (PyDict_Next(dict, &pos, &key, &value)) {
17441794
if (!decode_and_write_pair(self, buffer, key, value,
17451795
check_keys, options, top_level)) {
1746-
if (PyErr_Occurred()) {
1747-
PyObject *etype = NULL, *evalue = NULL, *etrace = NULL;
1748-
PyErr_Fetch(&etype, &evalue, &etrace);
1749-
PyObject *InvalidDocument = _error("InvalidDocument");
1750-
1751-
if (top_level && InvalidDocument && PyErr_GivenExceptionMatches(etype, InvalidDocument)) {
1752-
1753-
Py_DECREF(etype);
1754-
etype = InvalidDocument;
1755-
1756-
if (evalue) {
1757-
PyObject *msg = PyObject_Str(evalue);
1758-
Py_DECREF(evalue);
1759-
1760-
if (msg) {
1761-
// Prepend doc to the existing message
1762-
PyObject *dict_str = PyObject_Str(dict);
1763-
PyObject *new_msg = PyUnicode_FromFormat("Invalid document %s | %s", PyUnicode_AsUTF8(dict_str), PyUnicode_AsUTF8(msg));
1764-
Py_DECREF(dict_str);
1765-
1766-
if (new_msg) {
1767-
evalue = new_msg;
1768-
}
1769-
else {
1770-
evalue = msg;
1771-
}
1772-
}
1773-
}
1774-
PyErr_NormalizeException(&etype, &evalue, &etrace);
1775-
}
1776-
else {
1777-
Py_DECREF(InvalidDocument);
1778-
}
1779-
PyErr_Restore(etype, evalue, etrace);
1796+
if (PyErr_Occurred() && top_level) {
1797+
handle_invalid_doc_error(dict);
17801798
}
17811799
return 0;
17821800
}
@@ -1796,6 +1814,9 @@ int write_dict(PyObject* self, buffer_t buffer,
17961814
}
17971815
if (!decode_and_write_pair(self, buffer, key, value,
17981816
check_keys, options, top_level)) {
1817+
if (PyErr_Occurred() && top_level) {
1818+
handle_invalid_doc_error(dict);
1819+
}
17991820
Py_DECREF(key);
18001821
Py_DECREF(value);
18011822
Py_DECREF(iter);

pymongo/asynchronous/auth.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
_authenticate_oidc,
3939
_get_authenticator,
4040
)
41+
from pymongo.asynchronous.helpers import _getaddrinfo
4142
from pymongo.auth_shared import (
4243
MongoCredential,
4344
_authenticate_scram_start,
@@ -177,15 +178,22 @@ def _auth_key(nonce: str, username: str, password: str) -> str:
177178
return md5hash.hexdigest()
178179

179180

180-
def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
181+
async def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
181182
"""Canonicalize hostname following MIT-krb5 behavior."""
182183
# https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
183184
if option in [False, "none"]:
184185
return hostname
185186

186-
af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
187-
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
188-
)[0]
187+
af, socktype, proto, canonname, sockaddr = (
188+
await _getaddrinfo(
189+
hostname,
190+
None,
191+
family=0,
192+
type=0,
193+
proto=socket.IPPROTO_TCP,
194+
flags=socket.AI_CANONNAME,
195+
)
196+
)[0] # type: ignore[index]
189197

190198
# For forward just to resolve the cname as dns.lookup() will not return it.
191199
if option == "forward":
@@ -213,7 +221,7 @@ async def _authenticate_gssapi(credentials: MongoCredential, conn: AsyncConnecti
213221
# Starting here and continuing through the while loop below - establish
214222
# the security context. See RFC 4752, Section 3.1, first paragraph.
215223
host = props.service_host or conn.address[0]
216-
host = _canonicalize_hostname(host, props.canonicalize_host_name)
224+
host = await _canonicalize_hostname(host, props.canonicalize_host_name)
217225
service = props.service_name + "@" + host
218226
if props.service_realm is not None:
219227
service = service + "@" + props.service_realm

pymongo/asynchronous/encryption.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,14 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
219219
# Wrap I/O errors in PyMongo exceptions.
220220
if isinstance(exc, BLOCKING_IO_ERRORS):
221221
exc = socket.timeout("timed out")
222-
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
222+
# Async raises an OSError instead of returning empty bytes.
223+
if isinstance(exc, OSError):
224+
msg_prefix = "KMS connection closed"
225+
else:
226+
msg_prefix = None
227+
_raise_connection_failure(
228+
address, exc, msg_prefix=msg_prefix, timeout_details=_get_timeout_details(opts)
229+
)
223230
finally:
224231
conn.close()
225232
except MongoCryptError:

pymongo/asynchronous/helpers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
"""Miscellaneous pieces that need to be synchronized."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import builtins
20+
import socket
1921
import sys
2022
from typing import (
2123
Any,
@@ -68,6 +70,24 @@ async def inner(*args: Any, **kwargs: Any) -> Any:
6870
return cast(F, inner)
6971

7072

73+
async def _getaddrinfo(
    host: Any, port: Any, **kwargs: Any
) -> list[
    tuple[
        socket.AddressFamily,
        socket.SocketKind,
        int,
        str,
        tuple[str, int] | tuple[str, int, int, int],
    ]
]:
    """Resolve ``host``/``port`` into a list of address-info tuples.

    When ``_IS_SYNC`` is true the lookup is a direct call to
    :func:`socket.getaddrinfo`; otherwise it is awaited on the running
    event loop via ``loop.getaddrinfo``. All keyword arguments are
    forwarded unchanged to whichever resolver is used.
    """
    if _IS_SYNC:
        return socket.getaddrinfo(host, port, **kwargs)
    running_loop = asyncio.get_running_loop()
    return await running_loop.getaddrinfo(host, port, **kwargs)  # type: ignore[return-value]
89+
90+
7191
if sys.version_info >= (3, 10):
7292
anext = builtins.anext
7393
aiter = builtins.aiter

pymongo/asynchronous/pool.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from bson import DEFAULT_CODEC_OPTIONS
4141
from pymongo import _csot, helpers_shared
4242
from pymongo.asynchronous.client_session import _validate_session_write_concern
43-
from pymongo.asynchronous.helpers import _handle_reauth
43+
from pymongo.asynchronous.helpers import _getaddrinfo, _handle_reauth
4444
from pymongo.asynchronous.network import command, receive_message
4545
from pymongo.common import (
4646
MAX_BSON_SIZE,
@@ -783,7 +783,7 @@ def __repr__(self) -> str:
783783
)
784784

785785

786-
def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
786+
async def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
787787
"""Given (host, port) and PoolOptions, connect and return a socket object.
788788
789789
Can raise socket.error.
@@ -814,7 +814,7 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket
814814
family = socket.AF_UNSPEC
815815

816816
err = None
817-
for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
817+
for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined]
818818
af, socktype, proto, dummy, sa = res
819819
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
820820
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
@@ -863,7 +863,7 @@ async def _configured_socket(
863863
864864
Sets socket's SSL and timeout options.
865865
"""
866-
sock = _create_connection(address, options)
866+
sock = await _create_connection(address, options)
867867
ssl_context = options._ssl_context
868868

869869
if ssl_context is None:

pymongo/message.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import datetime
2525
import random
2626
import struct
27-
from collections import ChainMap
2827
from io import BytesIO as _BytesIO
2928
from typing import (
3029
TYPE_CHECKING,
@@ -1116,18 +1115,8 @@ def _check_doc_size_limits(
11161115
# key and the index of its namespace within ns_info as its value.
11171116
op_doc[op_type] = ns_info[namespace] # type: ignore[index]
11181117

1119-
# Since the data document itself is nested within the insert document
1120-
# it won't be automatically re-ordered by the BSON conversion.
1121-
# We use ChainMap here to make the _id field the first field instead.
1122-
doc_to_encode = op_doc
1123-
if real_op_type == "insert":
1124-
doc = op_doc["document"]
1125-
if not isinstance(doc, RawBSONDocument):
1126-
doc_to_encode = op_doc.copy() # type: ignore[attr-defined] # Shallow copy
1127-
doc_to_encode["document"] = ChainMap(doc, {"_id": doc["_id"]}) # type: ignore[index]
1128-
11291118
# Encode current operation doc and, if newly added, namespace doc.
1130-
op_doc_encoded = _dict_to_bson(doc_to_encode, False, opts)
1119+
op_doc_encoded = _dict_to_bson(op_doc, False, opts)
11311120
op_length = len(op_doc_encoded)
11321121
if ns_doc:
11331122
ns_doc_encoded = _dict_to_bson(ns_doc, False, opts)

pymongo/network_layer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
async def async_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
7373
timeout = sock.gettimeout()
7474
sock.settimeout(0.0)
75-
loop = asyncio.get_event_loop()
75+
loop = asyncio.get_running_loop()
7676
try:
7777
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
7878
await asyncio.wait_for(_async_sendall_ssl(sock, buf, loop), timeout=timeout)
@@ -259,7 +259,7 @@ async def async_receive_data(
259259
timeout = sock_timeout
260260

261261
sock.settimeout(0.0)
262-
loop = asyncio.get_event_loop()
262+
loop = asyncio.get_running_loop()
263263
cancellation_task = create_task(_poll_cancellation(conn))
264264
try:
265265
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
@@ -290,7 +290,7 @@ async def async_receive_data_socket(
290290
timeout = sock_timeout
291291

292292
sock.settimeout(0.0)
293-
loop = asyncio.get_event_loop()
293+
loop = asyncio.get_running_loop()
294294
try:
295295
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
296296
return await asyncio.wait_for(

pymongo/pyopenssl_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def __set_check_ocsp_endpoint(self, value: bool) -> None:
273273

274274
check_ocsp_endpoint = property(__get_check_ocsp_endpoint, __set_check_ocsp_endpoint)
275275

276-
def __get_options(self) -> None:
276+
def __get_options(self) -> int:
277277
# Calling set_options adds the option to the existing bitmask and
278278
# returns the new bitmask.
279279
# https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_options

pymongo/synchronous/auth.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
_authenticate_oidc,
4646
_get_authenticator,
4747
)
48+
from pymongo.synchronous.helpers import _getaddrinfo
4849

4950
if TYPE_CHECKING:
5051
from pymongo.hello import Hello
@@ -180,9 +181,16 @@ def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
180181
if option in [False, "none"]:
181182
return hostname
182183

183-
af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
184-
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
185-
)[0]
184+
af, socktype, proto, canonname, sockaddr = (
185+
_getaddrinfo(
186+
hostname,
187+
None,
188+
family=0,
189+
type=0,
190+
proto=socket.IPPROTO_TCP,
191+
flags=socket.AI_CANONNAME,
192+
)
193+
)[0] # type: ignore[index]
186194

187195
# For forward just to resolve the cname as dns.lookup() will not return it.
188196
if option == "forward":

pymongo/synchronous/encryption.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,14 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
219219
# Wrap I/O errors in PyMongo exceptions.
220220
if isinstance(exc, BLOCKING_IO_ERRORS):
221221
exc = socket.timeout("timed out")
222-
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
222+
# Async raises an OSError instead of returning empty bytes.
223+
if isinstance(exc, OSError):
224+
msg_prefix = "KMS connection closed"
225+
else:
226+
msg_prefix = None
227+
_raise_connection_failure(
228+
address, exc, msg_prefix=msg_prefix, timeout_details=_get_timeout_details(opts)
229+
)
223230
finally:
224231
conn.close()
225232
except MongoCryptError:

0 commit comments

Comments
 (0)