Skip to content

Commit a3cd704

Browse files
authored
PYTHON-4549 - Optimize Cursor.to_list (mongodb#1749)
1 parent d79eee5 commit a3cd704

File tree

8 files changed: 150 additions and 4 deletions.

gridfs/asynchronous/grid_file.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1892,6 +1892,9 @@ async def next(self) -> AsyncGridOut:
18921892
next_file = await super().next()
18931893
return AsyncGridOut(self._root_collection, file_document=next_file, session=self.session)
18941894

1895+
async def to_list(self) -> list[AsyncGridOut]:
1896+
return [x async for x in self] # noqa: C416,RUF100
1897+
18951898
__anext__ = next
18961899

18971900
def add_option(self, *args: Any, **kwargs: Any) -> NoReturn:

gridfs/synchronous/grid_file.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1878,6 +1878,9 @@ def next(self) -> GridOut:
18781878
next_file = super().next()
18791879
return GridOut(self._root_collection, file_document=next_file, session=self.session)
18801880

1881+
def to_list(self) -> list[GridOut]:
1882+
return [x for x in self] # noqa: C416,RUF100
1883+
18811884
__next__ = next
18821885

18831886
def add_option(self, *args: Any, **kwargs: Any) -> NoReturn:

pymongo/asynchronous/command_cursor.py

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -346,6 +346,17 @@ async def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]:
346346
else:
347347
return None
348348

349+
async def _next_batch(self, result: list) -> bool:
350+
"""Get all available documents from the cursor."""
351+
if not len(self._data) and not self._killed:
352+
await self._refresh()
353+
if len(self._data):
354+
result.extend(self._data)
355+
self._data.clear()
356+
return True
357+
else:
358+
return False
359+
349360
async def try_next(self) -> Optional[_DocumentType]:
350361
"""Advance the cursor without blocking indefinitely.
351362
@@ -371,7 +382,11 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
371382
await self.close()
372383

373384
async def to_list(self) -> list[_DocumentType]:
374-
return [x async for x in self] # noqa: C416,RUF100
385+
res: list[_DocumentType] = []
386+
while self.alive:
387+
if not await self._next_batch(res):
388+
break
389+
return res
375390

376391

377392
class AsyncRawBatchCommandCursor(AsyncCommandCursor[_DocumentType]):

pymongo/asynchronous/cursor.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1260,6 +1260,20 @@ async def next(self) -> _DocumentType:
12601260
else:
12611261
raise StopAsyncIteration
12621262

1263+
async def _next_batch(self, result: list) -> bool:
1264+
"""Get all available documents from the cursor."""
1265+
if not self._exhaust_checked:
1266+
self._exhaust_checked = True
1267+
await self._supports_exhaust()
1268+
if self._empty:
1269+
return False
1270+
if len(self._data) or await self._refresh():
1271+
result.extend(self._data)
1272+
self._data.clear()
1273+
return True
1274+
else:
1275+
return False
1276+
12631277
async def __anext__(self) -> _DocumentType:
12641278
return await self.next()
12651279

@@ -1273,7 +1287,11 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
12731287
await self.close()
12741288

12751289
async def to_list(self) -> list[_DocumentType]:
1276-
return [x async for x in self] # noqa: C416,RUF100
1290+
res: list[_DocumentType] = []
1291+
while self.alive:
1292+
if not await self._next_batch(res):
1293+
break
1294+
return res
12771295

12781296

12791297
class AsyncRawBatchCursor(AsyncCursor, Generic[_DocumentType]):

pymongo/synchronous/command_cursor.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,17 @@ def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]:
346346
else:
347347
return None
348348

349+
def _next_batch(self, result: list) -> bool:
350+
"""Get all available documents from the cursor."""
351+
if not len(self._data) and not self._killed:
352+
self._refresh()
353+
if len(self._data):
354+
result.extend(self._data)
355+
self._data.clear()
356+
return True
357+
else:
358+
return False
359+
349360
def try_next(self) -> Optional[_DocumentType]:
350361
"""Advance the cursor without blocking indefinitely.
351362
@@ -371,7 +382,11 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
371382
self.close()
372383

373384
def to_list(self) -> list[_DocumentType]:
374-
return [x for x in self] # noqa: C416,RUF100
385+
res: list[_DocumentType] = []
386+
while self.alive:
387+
if not self._next_batch(res):
388+
break
389+
return res
375390

376391

377392
class RawBatchCommandCursor(CommandCursor[_DocumentType]):

pymongo/synchronous/cursor.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1258,6 +1258,20 @@ def next(self) -> _DocumentType:
12581258
else:
12591259
raise StopIteration
12601260

1261+
def _next_batch(self, result: list) -> bool:
1262+
"""Get all available documents from the cursor."""
1263+
if not self._exhaust_checked:
1264+
self._exhaust_checked = True
1265+
self._supports_exhaust()
1266+
if self._empty:
1267+
return False
1268+
if len(self._data) or self._refresh():
1269+
result.extend(self._data)
1270+
self._data.clear()
1271+
return True
1272+
else:
1273+
return False
1274+
12611275
def __next__(self) -> _DocumentType:
12621276
return self.next()
12631277

@@ -1271,7 +1285,11 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
12711285
self.close()
12721286

12731287
def to_list(self) -> list[_DocumentType]:
1274-
return [x for x in self] # noqa: C416,RUF100
1288+
res: list[_DocumentType] = []
1289+
while self.alive:
1290+
if not self._next_batch(res):
1291+
break
1292+
return res
12751293

12761294

12771295
class RawBatchCursor(Cursor, Generic[_DocumentType]):

test/asynchronous/test_cursor.py

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1380,6 +1380,43 @@ async def test_getMore_does_not_send_readPreference(self):
13801380
self.assertEqual("getMore", started[1].command_name)
13811381
self.assertNotIn("$readPreference", started[1].command)
13821382

1383+
@async_client_context.require_replica_set
1384+
async def test_to_list_tailable(self):
1385+
oplog = self.client.local.oplog.rs
1386+
last = await oplog.find().sort("$natural", pymongo.DESCENDING).limit(-1).next()
1387+
ts = last["ts"]
1388+
1389+
c = oplog.find(
1390+
{"ts": {"$gte": ts}}, cursor_type=pymongo.CursorType.TAILABLE_AWAIT, oplog_replay=True
1391+
)
1392+
1393+
docs = await c.to_list()
1394+
1395+
self.assertGreaterEqual(len(docs), 1)
1396+
1397+
async def test_to_list_empty(self):
1398+
c = self.db.does_not_exist.find()
1399+
1400+
docs = await c.to_list()
1401+
1402+
self.assertEqual([], docs)
1403+
1404+
@async_client_context.require_replica_set
1405+
async def test_command_cursor_to_list(self):
1406+
c = await self.db.test.aggregate([{"$changeStream": {}}])
1407+
1408+
docs = await c.to_list()
1409+
1410+
self.assertGreaterEqual(len(docs), 0)
1411+
1412+
@async_client_context.require_replica_set
1413+
async def test_command_cursor_to_list_empty(self):
1414+
c = await self.db.does_not_exist.aggregate([{"$changeStream": {}}])
1415+
1416+
docs = await c.to_list()
1417+
1418+
self.assertEqual([], docs)
1419+
13831420

13841421
class TestRawBatchCursor(AsyncIntegrationTest):
13851422
async def test_find_raw(self):

test/test_cursor.py

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1371,6 +1371,43 @@ def test_getMore_does_not_send_readPreference(self):
13711371
self.assertEqual("getMore", started[1].command_name)
13721372
self.assertNotIn("$readPreference", started[1].command)
13731373

1374+
@client_context.require_replica_set
1375+
def test_to_list_tailable(self):
1376+
oplog = self.client.local.oplog.rs
1377+
last = oplog.find().sort("$natural", pymongo.DESCENDING).limit(-1).next()
1378+
ts = last["ts"]
1379+
1380+
c = oplog.find(
1381+
{"ts": {"$gte": ts}}, cursor_type=pymongo.CursorType.TAILABLE_AWAIT, oplog_replay=True
1382+
)
1383+
1384+
docs = c.to_list()
1385+
1386+
self.assertGreaterEqual(len(docs), 1)
1387+
1388+
def test_to_list_empty(self):
1389+
c = self.db.does_not_exist.find()
1390+
1391+
docs = c.to_list()
1392+
1393+
self.assertEqual([], docs)
1394+
1395+
@client_context.require_replica_set
1396+
def test_command_cursor_to_list(self):
1397+
c = self.db.test.aggregate([{"$changeStream": {}}])
1398+
1399+
docs = c.to_list()
1400+
1401+
self.assertGreaterEqual(len(docs), 0)
1402+
1403+
@client_context.require_replica_set
1404+
def test_command_cursor_to_list_empty(self):
1405+
c = self.db.does_not_exist.aggregate([{"$changeStream": {}}])
1406+
1407+
docs = c.to_list()
1408+
1409+
self.assertEqual([], docs)
1410+
13741411

13751412
class TestRawBatchCursor(IntegrationTest):
13761413
def test_find_raw(self):

0 commit comments

Comments
 (0)