Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 2973437

Browse files
Merge pull request #50 from encode/tighten-interface
Tighten interface
2 parents fbe39bf + 0a21c20 commit 2973437

File tree

6 files changed

+67
-13
lines changed

6 files changed

+67
-13
lines changed

databases/backends/mysql.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ async def release(self) -> None:
8383
await self._database._pool.release(self._connection)
8484
self._connection = None
8585

86-
async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]:
86+
async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:
8787
assert self._connection is not None, "Connection is not acquired"
8888
query, args, context = self._compile(query)
8989
cursor = await self._connection.cursor()
@@ -98,13 +98,15 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]:
9898
finally:
9999
await cursor.close()
100100

101-
async def fetch_one(self, query: ClauseElement) -> RowProxy:
101+
async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mapping]:
102102
assert self._connection is not None, "Connection is not acquired"
103103
query, args, context = self._compile(query)
104104
cursor = await self._connection.cursor()
105105
try:
106106
await cursor.execute(query, args)
107107
row = await cursor.fetchone()
108+
if row is None:
109+
return None
108110
metadata = ResultMetaData(context, cursor.description)
109111
return RowProxy(metadata, row, metadata._processors, metadata._keymap)
110112
finally:

databases/backends/postgres.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import typing
3+
from collections.abc import Mapping
34

45
import asyncpg
56
from sqlalchemy.dialects.postgresql import pypostgresql
@@ -63,7 +64,7 @@ def connection(self) -> "PostgresConnection":
6364
return PostgresConnection(self, self._dialect)
6465

6566

66-
class Record:
67+
class Record(Mapping):
6768
def __init__(self, row: tuple, result_columns: tuple, dialect: Dialect) -> None:
6869
self._row = row
6970
self._result_columns = result_columns
@@ -86,6 +87,12 @@ def __getitem__(self, key: str) -> typing.Any:
8687
return processor(raw)
8788
return raw
8889

90+
def __iter__(self) -> typing.Iterator:
91+
return iter(self._column_map)
92+
93+
def __len__(self) -> int:
94+
return len(self._column_map)
95+
8996

9097
class PostgresConnection(ConnectionBackend):
9198
def __init__(self, database: PostgresBackend, dialect: Dialect):
@@ -104,16 +111,18 @@ async def release(self) -> None:
104111
self._connection = await self._database._pool.release(self._connection)
105112
self._connection = None
106113

107-
async def fetch_all(self, query: ClauseElement) -> typing.Any:
114+
async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:
108115
assert self._connection is not None, "Connection is not acquired"
109116
query, args, result_columns = self._compile(query)
110117
rows = await self._connection.fetch(query, *args)
111118
return [Record(row, result_columns, self._dialect) for row in rows]
112119

113-
async def fetch_one(self, query: ClauseElement) -> typing.Any:
120+
async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mapping]:
114121
assert self._connection is not None, "Connection is not acquired"
115122
query, args, result_columns = self._compile(query)
116123
row = await self._connection.fetchrow(query, *args)
124+
if row is None:
125+
return None
117126
return Record(row, result_columns, self._dialect)
118127

119128
async def execute(self, query: ClauseElement, values: dict = None) -> typing.Any:

databases/backends/sqlite.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ async def release(self) -> None:
7979
await self._pool.release(self._connection)
8080
self._connection = None
8181

82-
async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]:
82+
async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:
8383
assert self._connection is not None, "Connection is not acquired"
8484
query, args, context = self._compile(query)
8585

@@ -91,16 +91,18 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]:
9191
for row in rows
9292
]
9393

94-
async def fetch_one(self, query: ClauseElement) -> RowProxy:
94+
async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mapping]:
9595
assert self._connection is not None, "Connection is not acquired"
9696
query, args, context = self._compile(query)
9797

9898
async with self._connection.execute(query, args) as cursor:
9999
row = await cursor.fetchone()
100+
if row is None:
101+
return None
100102
metadata = ResultMetaData(context, cursor.description)
101103
return RowProxy(metadata, row, metadata._processors, metadata._keymap)
102104

103-
async def execute(self, query: ClauseElement, values: dict = None) -> None:
105+
async def execute(self, query: ClauseElement, values: dict = None) -> typing.Any:
104106
assert self._connection is not None, "Connection is not acquired"
105107
if values is not None:
106108
query = query.values(values)

databases/interfaces.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import typing
22

3-
from sqlalchemy.engine import RowProxy
43
from sqlalchemy.sql import ClauseElement
54

65

@@ -22,21 +21,21 @@ async def acquire(self) -> None:
2221
async def release(self) -> None:
2322
raise NotImplementedError() # pragma: no cover
2423

25-
async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]:
24+
async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:
2625
raise NotImplementedError() # pragma: no cover
2726

28-
async def fetch_one(self, query: ClauseElement) -> RowProxy:
27+
async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mapping]:
2928
raise NotImplementedError() # pragma: no cover
3029

31-
async def execute(self, query: ClauseElement, values: dict = None) -> None:
30+
async def execute(self, query: ClauseElement, values: dict = None) -> typing.Any:
3231
raise NotImplementedError() # pragma: no cover
3332

3433
async def execute_many(self, query: ClauseElement, values: list) -> None:
3534
raise NotImplementedError() # pragma: no cover
3635

3736
async def iterate(
3837
self, query: ClauseElement
39-
) -> typing.AsyncGenerator[RowProxy, None]:
38+
) -> typing.AsyncGenerator[typing.Mapping, None]:
4039
raise NotImplementedError() # pragma: no cover
4140
# mypy needs async iterators to contain a `yield`
4241
# https://github.com/python/mypy/issues/5385#issuecomment-407281656

scripts/lint

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ set -x
1010
${PREFIX}autoflake --in-place --recursive databases tests
1111
${PREFIX}black databases tests
1212
${PREFIX}isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --apply databases tests
13+
${PREFIX}mypy databases --ignore-missing-imports --disallow-untyped-defs

tests/test_databases.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,47 @@ async def test_queries(database_url):
151151
assert iterate_results[2]["completed"] == True
152152

153153

154+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
155+
@async_adapter
156+
async def test_results_support_mapping_interface(database_url):
157+
"""
158+
Casting results to a dict should work, since the interface defines them
159+
as supporting the mapping interface.
160+
"""
161+
async with Database(database_url) as database:
162+
async with database.transaction(force_rollback=True):
163+
# execute()
164+
query = notes.insert()
165+
values = {"text": "example1", "completed": True}
166+
await database.execute(query, values)
167+
168+
# fetch_all()
169+
query = notes.select()
170+
results = await database.fetch_all(query=query)
171+
results_as_dicts = [dict(item) for item in results]
172+
173+
assert len(results[0]) == 3
174+
assert len(results_as_dicts[0]) == 3
175+
176+
assert isinstance(results_as_dicts[0]["id"], int)
177+
assert results_as_dicts[0]["text"] == "example1"
178+
assert results_as_dicts[0]["completed"] == True
179+
180+
181+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
182+
@async_adapter
183+
async def test_fetch_one_returning_no_results(database_url):
184+
"""
185+
fetch_one should return `None` when no results match.
186+
"""
187+
async with Database(database_url) as database:
188+
async with database.transaction(force_rollback=True):
189+
# fetch_all()
190+
query = notes.select()
191+
result = await database.fetch_one(query=query)
192+
assert result is None
193+
194+
154195
@pytest.mark.parametrize("database_url", DATABASE_URLS)
155196
@async_adapter
156197
async def test_execute_return_val(database_url):

0 commit comments

Comments
 (0)