Skip to content

Commit 6920dfc

Browse files
authored
Merge pull request #74 from encode/fetch_val
Add `fetch_val()` interface
2 parents 26c683e + 3295917 commit 6920dfc

File tree

4 files changed

+48
-14
lines changed

4 files changed

+48
-14
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ rows = await database.fetch_all(query=query)
9797
query = notes.select()
9898
row = await database.fetch_one(query=query)
9999

100+
# Fetch single value, defaults to `column=0`.
101+
query = notes.select()
102+
value = await database.fetch_val(query=query)
103+
100104
# Fetch multiple rows without loading them all into memory at once
101105
query = notes.select()
102106
async for row in database.iterate(query=query):

databases/backends/postgres.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sqlalchemy.engine.interfaces import Dialect
88
from sqlalchemy.sql import ClauseElement
99
from sqlalchemy.sql.schema import Column
10+
from sqlalchemy.types import TypeEngine
1011

1112
from databases.core import DatabaseURL
1213
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend
@@ -72,20 +73,27 @@ def __init__(
7273
self._row = row
7374
self._result_columns = result_columns
7475
self._dialect = dialect
75-
self._column_map = {
76-
column_name: (idx, datatype)
77-
for idx, (column_name, _, _, datatype) in enumerate(self._result_columns)
78-
}
79-
self._column_map_full = {
80-
str(column[0]): (idx, datatype)
81-
for idx, (_, _, column, datatype) in enumerate(self._result_columns)
82-
}
76+
self._column_map = (
77+
{}
78+
) # type: typing.Mapping[str, typing.Tuple[int, TypeEngine]]
79+
self._column_map_int = (
80+
{}
81+
) # type: typing.Mapping[int, typing.Tuple[int, TypeEngine]]
82+
self._column_map_full = (
83+
{}
84+
) # type: typing.Mapping[str, typing.Tuple[int, TypeEngine]]
85+
for idx, (column_name, _, column, datatype) in enumerate(self._result_columns):
86+
self._column_map[column_name] = (idx, datatype)
87+
self._column_map_int[idx] = (idx, datatype)
88+
self._column_map_full[str(column[0])] = (idx, datatype)
8389

8490
def __getitem__(self, key: typing.Any) -> typing.Any:
8591
if len(self._column_map) == 0: # raw query
8692
return self._row[tuple(self._row.keys()).index(key)]
8793
elif type(key) is Column:
8894
idx, datatype = self._column_map_full[str(key)]
95+
elif type(key) is int:
96+
idx, datatype = self._column_map_int[key]
8997
else:
9098
idx, datatype = self._column_map[key]
9199
raw = self._row[idx]

databases/core.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from urllib.parse import SplitResult, parse_qsl, urlsplit
77

88
from sqlalchemy import text
9-
from sqlalchemy.engine import RowProxy
109
from sqlalchemy.sql import ClauseElement
1110

1211
from databases.importer import import_from_string
@@ -91,16 +90,25 @@ async def __aexit__(
9190

9291
async def fetch_all(
9392
self, query: typing.Union[ClauseElement, str], values: dict = None
94-
) -> typing.List[RowProxy]:
93+
) -> typing.List[typing.Mapping]:
9594
async with self.connection() as connection:
9695
return await connection.fetch_all(query, values)
9796

9897
async def fetch_one(
9998
self, query: typing.Union[ClauseElement, str], values: dict = None
100-
) -> RowProxy:
99+
) -> typing.Optional[typing.Mapping]:
101100
async with self.connection() as connection:
102101
return await connection.fetch_one(query, values)
103102

103+
async def fetch_val(
104+
self,
105+
query: typing.Union[ClauseElement, str],
106+
values: dict = None,
107+
column: typing.Any = 0,
108+
) -> typing.Any:
109+
async with self.connection() as connection:
110+
return await connection.fetch_val(query, values, column=column)
111+
104112
async def execute(
105113
self, query: typing.Union[ClauseElement, str], values: dict = None
106114
) -> typing.Any:
@@ -115,7 +123,7 @@ async def execute_many(
115123

116124
async def iterate(
117125
self, query: typing.Union[ClauseElement, str], values: dict = None
118-
) -> typing.AsyncGenerator[RowProxy, None]:
126+
) -> typing.AsyncGenerator[typing.Mapping, None]:
119127
async with self.connection() as connection:
120128
async for record in connection.iterate(query, values):
121129
yield record
@@ -167,14 +175,23 @@ async def __aexit__(
167175

168176
async def fetch_all(
169177
self, query: typing.Union[ClauseElement, str], values: dict = None
170-
) -> typing.Any:
178+
) -> typing.List[typing.Mapping]:
171179
return await self._connection.fetch_all(self._build_query(query, values))
172180

173181
async def fetch_one(
174182
self, query: typing.Union[ClauseElement, str], values: dict = None
175-
) -> typing.Any:
183+
) -> typing.Optional[typing.Mapping]:
176184
return await self._connection.fetch_one(self._build_query(query, values))
177185

186+
async def fetch_val(
187+
self,
188+
query: typing.Union[ClauseElement, str],
189+
values: dict = None,
190+
column: typing.Any = 0,
191+
) -> typing.Any:
192+
row = await self._connection.fetch_one(self._build_query(query, values))
193+
return None if row is None else row[column]
194+
178195
async def execute(
179196
self, query: typing.Union[ClauseElement, str], values: dict = None
180197
) -> typing.Any:

tests/test_databases.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@ async def test_queries(database_url):
137137
assert result["text"] == "example1"
138138
assert result["completed"] == True
139139

140+
# fetch_val()
141+
query = sqlalchemy.sql.select([notes.c.text])
142+
result = await database.fetch_val(query=query)
143+
assert result == "example1"
144+
140145
# iterate()
141146
query = notes.select()
142147
iterate_results = []

0 commit comments

Comments
 (0)