Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion databases/backends/aiopg.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
cursor.close()

async def iterate(
self, query: ClauseElement
self, query: ClauseElement, n: int = None
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

n: Optional[int] = None

) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
Expand All @@ -195,6 +195,10 @@ async def iterate(
Row._default_key_style,
row,
)
if n is not None:
n -= 1
if n == 0:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe use if n <= 0 in case if negative n was passed

break
finally:
cursor.close()

Expand Down
6 changes: 5 additions & 1 deletion databases/backends/asyncmy.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
await cursor.close()

async def iterate(
self, query: ClauseElement
self, query: ClauseElement, n: int = None
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
Expand All @@ -185,6 +185,10 @@ async def iterate(
Row._default_key_style,
row,
)
if n is not None:
n -= 1
if n == 0:
break
finally:
await cursor.close()

Expand Down
6 changes: 5 additions & 1 deletion databases/backends/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
await cursor.close()

async def iterate(
self, query: ClauseElement
self, query: ClauseElement, n: int = None
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
Expand All @@ -185,6 +185,10 @@ async def iterate(
Row._default_key_style,
row,
)
if n is not None:
n -= 1
if n == 0:
break
finally:
await cursor.close()

Expand Down
7 changes: 5 additions & 2 deletions databases/backends/postgres.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import typing
from collections.abc import Sequence

import asyncpg
from sqlalchemy.dialects.postgresql import pypostgresql
Expand Down Expand Up @@ -227,13 +226,17 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
await self._connection.execute(single_query, *args)

async def iterate(
self, query: ClauseElement
self, query: ClauseElement, n: int = None
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns = self._compile(query)
column_maps = self._create_column_maps(result_columns)
async for row in self._connection.cursor(query_str, *args):
yield Record(row, result_columns, self._dialect, column_maps)
if n is not None:
n -= 1
if n == 0:
break

def transaction(self) -> TransactionBackend:
return PostgresTransaction(connection=self)
Expand Down
6 changes: 5 additions & 1 deletion databases/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
await self.execute(single_query)

async def iterate(
self, query: ClauseElement
self, query: ClauseElement, n: int = None
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
Expand All @@ -155,6 +155,10 @@ async def iterate(
Row._default_key_style,
row,
)
if n is not None:
n -= 1
if n == 0:
break

def transaction(self) -> TransactionBackend:
return SQLiteTransaction(self)
Expand Down
14 changes: 10 additions & 4 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,13 @@ async def execute_many(
return await connection.execute_many(query, values)

async def iterate(
self, query: typing.Union[ClauseElement, str], values: dict = None
self,
query: typing.Union[ClauseElement, str],
values: dict = None,
n: int = None,
) -> typing.AsyncGenerator[typing.Mapping, None]:
async with self.connection() as connection:
async for record in connection.iterate(query, values):
async for record in connection.iterate(query, values, n):
yield record

def _new_connection(self) -> "Connection":
Expand Down Expand Up @@ -307,12 +310,15 @@ async def execute_many(
await self._connection.execute_many(queries)

async def iterate(
self, query: typing.Union[ClauseElement, str], values: dict = None
self,
query: typing.Union[ClauseElement, str],
values: dict = None,
n: int = None,
) -> typing.AsyncGenerator[typing.Any, None]:
built_query = self._build_query(query, values)
async with self.transaction():
async with self._query_lock:
async for record in self._connection.iterate(built_query):
async for record in self._connection.iterate(built_query, n):
yield record

def transaction(
Expand Down
2 changes: 1 addition & 1 deletion databases/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
raise NotImplementedError() # pragma: no cover

async def iterate(
self, query: ClauseElement
self, query: ClauseElement, n: int = None
) -> typing.AsyncGenerator[typing.Mapping, None]:
raise NotImplementedError() # pragma: no cover
# mypy needs async iterators to contain a `yield`
Expand Down
11 changes: 11 additions & 0 deletions tests/test_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,17 @@ async def test_queries(database_url):
assert iterate_results[2]["text"] == "example3"
assert iterate_results[2]["completed"] == True

# iterate() with custom number of records
query = notes.select()
iterate_results = []
async for result in database.iterate(query=query, n=2):
iterate_results.append(result)
assert len(iterate_results) == 2
assert iterate_results[0]["text"] == "example1"
assert iterate_results[0]["completed"] == True
assert iterate_results[1]["text"] == "example2"
assert iterate_results[1]["completed"] == False


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@mysql_versions
Expand Down