Skip to content

Commit 1b83cf8

Browse files
committed
Merge branch 'master' of github.com:mongodb/mongo-python-driver
2 parents a665877 + 3e5387e commit 1b83cf8

30 files changed

+4897
-443
lines changed

.evergreen/run-tests.sh

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ set -o xtrace
3030

3131
AUTH=${AUTH:-noauth}
3232
SSL=${SSL:-nossl}
33+
TEST_SUITES="test/ test/asynchronous/"
3334
TEST_ARGS="${*:1}"
3435

3536
export PIP_QUIET=1 # Quiet by default
@@ -95,7 +96,7 @@ if [ -n "$TEST_LOADBALANCER" ]; then
9596
export LOAD_BALANCER=1
9697
export SINGLE_MONGOS_LB_URI="${SINGLE_MONGOS_LB_URI:-mongodb://127.0.0.1:8000/?loadBalanced=true}"
9798
export MULTI_MONGOS_LB_URI="${MULTI_MONGOS_LB_URI:-mongodb://127.0.0.1:8001/?loadBalanced=true}"
98-
export TEST_ARGS="test/test_load_balancer.py"
99+
export TEST_SUITES="test/test_load_balancer.py"
99100
fi
100101

101102
if [ "$SSL" != "nossl" ]; then
@@ -171,9 +172,7 @@ if [ -n "$TEST_ENCRYPTION" ]; then
171172
export PATH=$CRYPT_SHARED_DIR:$PATH
172173
fi
173174
# Only run the encryption tests.
174-
if [ -z "$TEST_ARGS" ]; then
175-
TEST_ARGS="test/test_encryption.py"
176-
fi
175+
TEST_SUITES="test/test_encryption.py"
177176
fi
178177

179178
if [ -n "$TEST_FLE_AZURE_AUTO" ] || [ -n "$TEST_FLE_GCP_AUTO" ]; then
@@ -187,9 +186,7 @@ if [ -n "$TEST_FLE_AZURE_AUTO" ] || [ -n "$TEST_FLE_GCP_AUTO" ]; then
187186
exit 1
188187
fi
189188

190-
if [ -z "$TEST_ARGS" ]; then
191-
TEST_ARGS="test/test_on_demand_csfle.py"
192-
fi
189+
TEST_SUITES="test/test_on_demand_csfle.py"
193190
fi
194191

195192
if [ -n "$TEST_INDEX_MANAGEMENT" ]; then
@@ -198,36 +195,36 @@ if [ -n "$TEST_INDEX_MANAGEMENT" ]; then
198195
set +x
199196
export DB_PASSWORD="${DRIVERS_ATLAS_LAMBDA_PASSWORD}"
200197
set -x
201-
TEST_ARGS="test/test_index_management.py"
198+
TEST_SUITES="test/test_index_management.py"
202199
fi
203200

204201
if [ -n "$TEST_DATA_LAKE" ] && [ -z "$TEST_ARGS" ]; then
205-
TEST_ARGS="test/test_data_lake.py"
202+
TEST_SUITES="test/test_data_lake.py"
206203
fi
207204

208205
if [ -n "$TEST_ATLAS" ]; then
209-
TEST_ARGS="test/atlas/test_connection.py"
206+
TEST_SUITES="test/atlas/test_connection.py"
210207
fi
211208

212209
if [ -n "$TEST_OCSP" ]; then
213210
python -m pip install ".[ocsp]"
214-
TEST_ARGS="test/ocsp/test_ocsp.py"
211+
TEST_SUITES="test/ocsp/test_ocsp.py"
215212
fi
216213

217214
if [ -n "$TEST_AUTH_AWS" ]; then
218215
python -m pip install ".[aws]"
219-
TEST_ARGS="test/auth_aws/test_auth_aws.py"
216+
TEST_SUITES="test/auth_aws/test_auth_aws.py"
220217
fi
221218

222219
if [ -n "$TEST_AUTH_OIDC" ]; then
223220
python -m pip install ".[aws]"
224-
TEST_ARGS="test/auth_oidc/test_auth_oidc.py $TEST_ARGS"
221+
TEST_SUITES="test/auth_oidc/test_auth_oidc.py $TEST_ARGS"
225222
fi
226223

227224
if [ -n "$PERF_TEST" ]; then
228225
python -m pip install simplejson
229226
start_time=$(date +%s)
230-
TEST_ARGS="test/performance/perf_test.py"
227+
TEST_SUITES="test/performance/perf_test.py"
231228
fi
232229

233230
echo "Running $AUTH tests over $SSL with python $(which python)"
@@ -257,8 +254,7 @@ PIP_QUIET=0 python -m pip list
257254
if [ -z "$GREEN_FRAMEWORK" ]; then
258255
# Use --capture=tee-sys so pytest prints test output inline:
259256
# https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html
260-
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS
261-
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 test/asynchronous/ $TEST_ARGS
257+
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_SUITES $TEST_ARGS
262258
else
263259
python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS
264260
fi

.github/workflows/test-python.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,11 @@ jobs:
205205
python-version: '3.8'
206206
- name: Start MongoDB
207207
uses: supercharge/[email protected]
208-
- name: Run Test
208+
- name: Run connect test from sdist
209209
shell: bash
210210
run: |
211211
cd sdist/test
212212
ls
213213
which python
214214
pip install -e ".[test]"
215-
pytest -v
216-
pytest -v test/asynchronous/
215+
PYMONGO_MUST_CONNECT=1 pytest -v test/test_client_context.py

gridfs/asynchronous/grid_file.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ async def put(self, data: Any, **kwargs: Any) -> Any:
154154
"""
155155
async with AsyncGridIn(self._collection, **kwargs) as grid_file:
156156
await grid_file.write(data)
157-
return await grid_file._id
157+
return grid_file._id
158158

159159
async def get(self, file_id: Any, session: Optional[AsyncClientSession] = None) -> AsyncGridOut:
160160
"""Get a file from GridFS by ``"_id"``.
@@ -220,7 +220,7 @@ async def get_version(
220220
query["filename"] = filename
221221

222222
_disallow_transactions(session)
223-
cursor = await self._files.find(query, session=session)
223+
cursor = self._files.find(query, session=session)
224224
if version is None:
225225
version = -1
226226
if version < 0:
@@ -923,7 +923,7 @@ async def open_download_stream_by_name(
923923
validate_string("filename", filename)
924924
query = {"filename": filename}
925925
_disallow_transactions(session)
926-
cursor = await self._files.find(query, session=session)
926+
cursor = self._files.find(query, session=session)
927927
if revision < 0:
928928
skip = abs(revision) - 1
929929
cursor.limit(-1).skip(skip).sort("uploadDate", DESCENDING)
@@ -1760,12 +1760,12 @@ def expected_chunk_length(self, chunk_n: int) -> int:
17601760
def __aiter__(self) -> _AsyncGridOutChunkIterator:
17611761
return self
17621762

1763-
async def _create_cursor(self) -> None:
1763+
def _create_cursor(self) -> None:
17641764
filter = {"files_id": self._id}
17651765
if self._next_chunk > 0:
17661766
filter["n"] = {"$gte": self._next_chunk}
17671767
_disallow_transactions(self._session)
1768-
self._cursor = await self._chunks.find(filter, sort=[("n", 1)], session=self._session)
1768+
self._cursor = self._chunks.find(filter, sort=[("n", 1)], session=self._session)
17691769

17701770
async def _next_with_retry(self) -> Mapping[str, Any]:
17711771
"""Return the next chunk and retry once on CursorNotFound.
@@ -1775,13 +1775,13 @@ async def _next_with_retry(self) -> Mapping[str, Any]:
17751775
server's default cursor timeout).
17761776
"""
17771777
if self._cursor is None:
1778-
await self._create_cursor()
1778+
self._create_cursor()
17791779
assert self._cursor is not None
17801780
try:
17811781
return await self._cursor.next()
17821782
except CursorNotFound:
17831783
await self._cursor.close()
1784-
await self._create_cursor()
1784+
self._create_cursor()
17851785
return await self._cursor.next()
17861786

17871787
async def next(self) -> Mapping[str, Any]:

pymongo/asynchronous/client_session.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@
144144
Any,
145145
AsyncContextManager,
146146
Callable,
147+
Coroutine,
147148
Mapping,
148149
MutableMapping,
149150
NoReturn,
@@ -598,7 +599,7 @@ def _inherit_option(self, name: str, val: _T) -> _T:
598599

599600
async def with_transaction(
600601
self,
601-
callback: Callable[[AsyncClientSession], _T],
602+
callback: Callable[[AsyncClientSession], Coroutine[Any, Any, _T]],
602603
read_concern: Optional[ReadConcern] = None,
603604
write_concern: Optional[WriteConcern] = None,
604605
read_preference: Optional[_ServerMode] = None,
@@ -693,7 +694,7 @@ async def callback(session, custom_arg, custom_kwarg=None):
693694
read_concern, write_concern, read_preference, max_commit_time_ms
694695
)
695696
try:
696-
ret = callback(self)
697+
ret = await callback(self)
697698
except Exception as exc:
698699
if self.in_transaction:
699700
await self.abort_transaction()

pymongo/asynchronous/collection.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1700,18 +1700,18 @@ async def find_one(
17001700
"""
17011701
if filter is not None and not isinstance(filter, abc.Mapping):
17021702
filter = {"_id": filter}
1703-
cursor = await self.find(filter, *args, **kwargs)
1703+
cursor = self.find(filter, *args, **kwargs)
17041704
async for result in cursor.limit(-1):
17051705
return result
17061706
return None
17071707

1708-
async def find(self, *args: Any, **kwargs: Any) -> AsyncCursor[_DocumentType]:
1708+
def find(self, *args: Any, **kwargs: Any) -> AsyncCursor[_DocumentType]:
17091709
"""Query the database.
17101710
17111711
The `filter` argument is a query document that all results
17121712
must match. For example:
17131713
1714-
>>> await db.test.find({"hello": "world"})
1714+
>>> db.test.find({"hello": "world"})
17151715
17161716
only matches documents that have a key "hello" with value
17171717
"world". Matches can have other keys *in addition* to
@@ -1891,9 +1891,7 @@ async def find(self, *args: Any, **kwargs: Any) -> AsyncCursor[_DocumentType]:
18911891
18921892
.. seealso:: The MongoDB documentation on `find <https://dochub.mongodb.org/core/find>`_.
18931893
"""
1894-
cursor = AsyncCursor(self, *args, **kwargs)
1895-
await cursor._supports_exhaust()
1896-
return cursor
1894+
return AsyncCursor(self, *args, **kwargs)
18971895

18981896
async def find_raw_batches(
18991897
self, *args: Any, **kwargs: Any

pymongo/asynchronous/command_cursor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,4 +424,4 @@ def _unpack_response( # type: ignore[override]
424424
return raw_response # type: ignore[return-value]
425425

426426
def __getitem__(self, index: int) -> NoReturn:
427-
raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor")
427+
raise InvalidOperation("Cannot call __getitem__ on AsyncRawBatchCommandCursor")

pymongo/asynchronous/cursor.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,14 @@ def __init__(
237237
self._dbname = collection.database.name
238238
self._collname = collection.name
239239

240+
# Checking exhaust cursor support requires network IO
241+
if _IS_SYNC:
242+
self._exhaust_checked = True
243+
self._supports_exhaust() # type: ignore[unused-coroutine]
244+
else:
245+
self._exhaust = cursor_type == CursorType.EXHAUST
246+
self._exhaust_checked = False
247+
240248
async def _supports_exhaust(self) -> None:
241249
# Exhaust cursor support
242250
if self._cursor_type == CursorType.EXHAUST:
@@ -1242,6 +1250,9 @@ async def rewind(self) -> AsyncCursor[_DocumentType]:
12421250

12431251
async def next(self) -> _DocumentType:
12441252
"""Advance the cursor."""
1253+
if not self._exhaust_checked:
1254+
self._exhaust_checked = True
1255+
await self._supports_exhaust()
12451256
if self._empty:
12461257
raise StopAsyncIteration
12471258
if len(self._data) or await self._refresh():
@@ -1308,4 +1319,4 @@ async def explain(self) -> _DocumentType:
13081319
return await clone.explain()
13091320

13101321
def __getitem__(self, index: Any) -> NoReturn:
1311-
raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor")
1322+
raise InvalidOperation("Cannot call __getitem__ on AsyncRawBatchCursor")

pymongo/asynchronous/encryption.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ async def fetch_keys(self, filter: bytes) -> AsyncGenerator[bytes, None]:
260260
:return: A generator which yields the requested keys from the key vault.
261261
"""
262262
assert self.key_vault_coll is not None
263-
async with await self.key_vault_coll.find(RawBSONDocument(filter)) as cursor:
263+
async with self.key_vault_coll.find(RawBSONDocument(filter)) as cursor:
264264
async for key in cursor:
265265
yield key.raw
266266

@@ -975,7 +975,7 @@ async def get_key(self, id: Binary) -> Optional[RawBSONDocument]:
975975
assert self._key_vault_coll is not None
976976
return await self._key_vault_coll.find_one({"_id": id})
977977

978-
async def get_keys(self) -> AsyncCursor[RawBSONDocument]:
978+
def get_keys(self) -> AsyncCursor[RawBSONDocument]:
979979
"""Get all of the data keys.
980980
981981
:return: An instance of :class:`~pymongo.cursor.Cursor` over the data key
@@ -985,7 +985,7 @@ async def get_keys(self) -> AsyncCursor[RawBSONDocument]:
985985
"""
986986
self._check_closed()
987987
assert self._key_vault_coll is not None
988-
return await self._key_vault_coll.find({})
988+
return self._key_vault_coll.find({})
989989

990990
async def delete_key(self, id: Binary) -> DeleteResult:
991991
"""Delete a key document in the key vault collection that has the given ``key_id``.

pymongo/asynchronous/pool.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414

1515
from __future__ import annotations
1616

17+
import asyncio
1718
import collections
1819
import contextlib
20+
import functools
1921
import logging
2022
import os
2123
import socket
@@ -876,12 +878,23 @@ async def _configured_socket(
876878
if _IS_SYNC:
877879
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
878880
else:
879-
ssl_sock = await ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc]
881+
if hasattr(ssl_context, "a_wrap_socket"):
882+
ssl_sock = await ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc]
883+
else:
884+
loop = asyncio.get_running_loop()
885+
ssl_sock = await loop.run_in_executor(
886+
None,
887+
functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc]
888+
)
880889
else:
881890
if _IS_SYNC:
882891
ssl_sock = ssl_context.wrap_socket(sock)
883892
else:
884-
ssl_sock = await ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc]
893+
if hasattr(ssl_context, "a_wrap_socket"):
894+
ssl_sock = await ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc]
895+
else:
896+
loop = asyncio.get_running_loop()
897+
ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc]
885898
except _CertificateError:
886899
sock.close()
887900
# Raise _CertificateError directly like we do after match_hostname

pymongo/synchronous/collection.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,7 +1710,7 @@ def find(self, *args: Any, **kwargs: Any) -> Cursor[_DocumentType]:
17101710
The `filter` argument is a query document that all results
17111711
must match. For example:
17121712
1713-
>>> await db.test.find({"hello": "world"})
1713+
>>> db.test.find({"hello": "world"})
17141714
17151715
only matches documents that have a key "hello" with value
17161716
"world". Matches can have other keys *in addition* to
@@ -1890,9 +1890,7 @@ def find(self, *args: Any, **kwargs: Any) -> Cursor[_DocumentType]:
18901890
18911891
.. seealso:: The MongoDB documentation on `find <https://dochub.mongodb.org/core/find>`_.
18921892
"""
1893-
cursor = Cursor(self, *args, **kwargs)
1894-
cursor._supports_exhaust()
1895-
return cursor
1893+
return Cursor(self, *args, **kwargs)
18961894

18971895
def find_raw_batches(self, *args: Any, **kwargs: Any) -> RawBatchCursor[_DocumentType]:
18981896
"""Query the database and retrieve batches of raw BSON.

0 commit comments

Comments
 (0)