Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 5fdf35a

Browse files
committed
Added support for raw queries using sqlalchemy.text()
1 parent aec964c commit 5fdf35a

File tree

6 files changed

+113
-42
lines changed

6 files changed

+113
-42
lines changed

databases/backends/mysql.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,8 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mappin
112112
finally:
113113
await cursor.close()
114114

115-
async def execute(self, query: ClauseElement, values: dict = None) -> typing.Any:
115+
async def execute(self, query: ClauseElement) -> typing.Any:
116116
assert self._connection is not None, "Connection is not acquired"
117-
if values is not None:
118-
query = query.values(values)
119117
query, args, context = self._compile(query)
120118
cursor = await self._connection.cursor()
121119
try:
@@ -124,12 +122,11 @@ async def execute(self, query: ClauseElement, values: dict = None) -> typing.Any
124122
finally:
125123
await cursor.close()
126124

127-
async def execute_many(self, query: ClauseElement, values: list) -> None:
125+
async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
128126
assert self._connection is not None, "Connection is not acquired"
129127
cursor = await self._connection.cursor()
130128
try:
131-
for item in values:
132-
single_query = query.values(item)
129+
for single_query in queries:
133130
single_query, args, context = self._compile(single_query)
134131
await cursor.execute(single_query, args)
135132
finally:

databases/backends/postgres.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def connection(self) -> "PostgresConnection":
6666

6767

6868
class Record(Mapping):
69-
def __init__(self, row: tuple, result_columns: tuple, dialect: Dialect) -> None:
69+
def __init__(
70+
self, row: asyncpg.Record, result_columns: tuple, dialect: Dialect
71+
) -> None:
7072
self._row = row
7173
self._result_columns = result_columns
7274
self._dialect = dialect
@@ -80,11 +82,14 @@ def __init__(self, row: tuple, result_columns: tuple, dialect: Dialect) -> None:
8082
}
8183

8284
def __getitem__(self, key: typing.Any) -> typing.Any:
83-
if type(key) is Column:
84-
idx, datatype = self._column_map_full[str(key)]
85-
else:
86-
idx, datatype = self._column_map[key]
87-
raw = self._row[idx]
85+
if len(self._column_map) == 0: # raw query
86+
return self._row[tuple(self._row.keys()).index(key)]
87+
if len(self._column_map) > 0:
88+
if type(key) is Column:
89+
idx, datatype = self._column_map_full[str(key)]
90+
else:
91+
idx, datatype = self._column_map[key]
92+
raw = self._row[idx]
8893
try:
8994
processor = _result_processors[datatype]
9095
except KeyError:
@@ -133,20 +138,17 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mappin
133138
return None
134139
return Record(row, result_columns, self._dialect)
135140

136-
async def execute(self, query: ClauseElement, values: dict = None) -> typing.Any:
141+
async def execute(self, query: ClauseElement) -> typing.Any:
137142
assert self._connection is not None, "Connection is not acquired"
138-
if values is not None:
139-
query = query.values(values)
140143
query, args, result_columns = self._compile(query)
141144
return await self._connection.fetchval(query, *args)
142145

143-
async def execute_many(self, query: ClauseElement, values: list) -> None:
146+
async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
144147
assert self._connection is not None, "Connection is not acquired"
145148
# asyncpg uses prepared statements under the hood, so we just
146149
# loop through multiple executes here, which should all end up
147150
# using the same prepared statement.
148-
for item in values:
149-
single_query = query.values(item)
151+
for single_query in queries:
150152
single_query, args, result_columns = self._compile(single_query)
151153
await self._connection.execute(single_query, *args)
152154

databases/backends/sqlite.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,17 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mappin
102102
metadata = ResultMetaData(context, cursor.description)
103103
return RowProxy(metadata, row, metadata._processors, metadata._keymap)
104104

105-
async def execute(self, query: ClauseElement, values: dict = None) -> typing.Any:
105+
async def execute(self, query: ClauseElement) -> typing.Any:
106106
assert self._connection is not None, "Connection is not acquired"
107-
if values is not None:
108-
query = query.values(values)
109107
query, args, context = self._compile(query)
110108
cursor = await self._connection.execute(query, args)
111109
await cursor.close()
112110
return cursor.lastrowid
113111

114-
async def execute_many(self, query: ClauseElement, values: list) -> None:
112+
async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
115113
assert self._connection is not None, "Connection is not acquired"
116-
for value in values:
117-
await self.execute(query, value)
114+
for single_query in queries:
115+
await self.execute(single_query)
118116

119117
async def iterate(
120118
self, query: ClauseElement

databases/core.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from types import TracebackType
66
from urllib.parse import SplitResult, parse_qsl, urlsplit
77

8+
from sqlalchemy import text
89
from sqlalchemy.engine import RowProxy
910
from sqlalchemy.sql import ClauseElement
1011

@@ -88,27 +89,36 @@ async def __aexit__(
8889
) -> None:
8990
await self.disconnect()
9091

91-
async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]:
92+
async def fetch_all(
93+
self, query: typing.Union[ClauseElement, str], values: dict = None
94+
) -> typing.List[RowProxy]:
9295
async with self.connection() as connection:
93-
return await connection.fetch_all(query=query)
96+
return await connection.fetch_all(query=self._build_query(query, values))
9497

95-
async def fetch_one(self, query: ClauseElement) -> RowProxy:
98+
async def fetch_one(
99+
self, query: typing.Union[ClauseElement, str], values: dict = None
100+
) -> RowProxy:
96101
async with self.connection() as connection:
97-
return await connection.fetch_one(query=query)
102+
return await connection.fetch_one(query=self._build_query(query, values))
98103

99-
async def execute(self, query: ClauseElement, values: dict = None) -> typing.Any:
104+
async def execute(
105+
self, query: typing.Union[ClauseElement, str], values: dict = None
106+
) -> typing.Any:
100107
async with self.connection() as connection:
101-
return await connection.execute(query=query, values=values)
108+
return await connection.execute(self._build_query(query, values))
102109

103-
async def execute_many(self, query: ClauseElement, values: list) -> None:
110+
async def execute_many(
111+
self, query: typing.Union[ClauseElement, str], values: list
112+
) -> None:
104113
async with self.connection() as connection:
105-
return await connection.execute_many(query=query, values=values)
114+
queries = [self._build_query(query, values_set) for values_set in values]
115+
return await connection.execute_many(queries)
106116

107117
async def iterate(
108-
self, query: ClauseElement
118+
self, query: typing.Union[ClauseElement, str], values: dict = None
109119
) -> typing.AsyncGenerator[RowProxy, None]:
110120
async with self.connection() as connection:
111-
async for record in connection.iterate(query):
121+
async for record in connection.iterate(self._build_query(query, values)):
112122
yield record
113123

114124
def connection(self) -> "Connection":
@@ -125,6 +135,19 @@ def connection(self) -> "Connection":
125135
def transaction(self, *, force_rollback: bool = False) -> "Transaction":
126136
return self.connection().transaction(force_rollback=force_rollback)
127137

138+
@staticmethod
139+
def _build_query(
140+
query: typing.Union[ClauseElement, str], values: dict = None
141+
) -> ClauseElement:
142+
if isinstance(query, str):
143+
query = text(query)
144+
145+
return query.bindparams(**values) if values is not None else query
146+
elif values:
147+
return query.values(**values)
148+
149+
return query
150+
128151

129152
class Connection:
130153
def __init__(self, backend: DatabaseBackend) -> None:
@@ -162,11 +185,11 @@ async def fetch_all(self, query: ClauseElement) -> typing.Any:
162185
async def fetch_one(self, query: ClauseElement) -> typing.Any:
163186
return await self._connection.fetch_one(query=query)
164187

165-
async def execute(self, query: ClauseElement, values: dict = None) -> typing.Any:
166-
return await self._connection.execute(query, values)
188+
async def execute(self, query: ClauseElement) -> typing.Any:
189+
return await self._connection.execute(query)
167190

168-
async def execute_many(self, query: ClauseElement, values: list) -> None:
169-
await self._connection.execute_many(query, values)
191+
async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
192+
await self._connection.execute_many(queries)
170193

171194
async def iterate(
172195
self, query: ClauseElement

databases/interfaces.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:
2727
async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mapping]:
2828
raise NotImplementedError() # pragma: no cover
2929

30-
async def execute(self, query: ClauseElement, values: dict = None) -> typing.Any:
30+
async def execute(self, query: ClauseElement) -> typing.Any:
3131
raise NotImplementedError() # pragma: no cover
3232

33-
async def execute_many(self, query: ClauseElement, values: list) -> None:
33+
async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
3434
raise NotImplementedError() # pragma: no cover
3535

3636
async def iterate(

tests/test_databases.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def run_sync(*args, **kwargs):
103103
async def test_queries(database_url):
104104
"""
105105
Test that the basic `execute()`, `execute_many()`, `fetch_all()``, and
106-
`fetch_one()` interfaces are all supported.
106+
`fetch_one()` interfaces are all supported (using SQLAlchemy core).
107107
"""
108108
async with Database(database_url) as database:
109109
async with database.transaction(force_rollback=True):
@@ -151,6 +151,57 @@ async def test_queries(database_url):
151151
assert iterate_results[2]["completed"] == True
152152

153153

154+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
155+
@async_adapter
156+
async def test_queries_raw(database_url):
157+
"""
158+
Test that the basic `execute()`, `execute_many()`, `fetch_all()``, and
159+
`fetch_one()` interfaces are all supported (raw queries).
160+
"""
161+
async with Database(database_url) as database:
162+
async with database.transaction(force_rollback=True):
163+
# execute()
164+
query = "INSERT INTO notes(text, completed) VALUES (:text, :completed)"
165+
values = {"text": "example1", "completed": True}
166+
await database.execute(query, values)
167+
168+
# execute_many()
169+
query = "INSERT INTO notes(text, completed) VALUES (:text, :completed)"
170+
values = [
171+
{"text": "example2", "completed": False},
172+
{"text": "example3", "completed": True},
173+
]
174+
await database.execute_many(query, values)
175+
176+
# fetch_all()
177+
query = "SELECT * FROM notes WHERE completed = :completed"
178+
results = await database.fetch_all(query=query, values={"completed": True})
179+
assert len(results) == 2
180+
assert results[0]["text"] == "example1"
181+
assert results[0]["completed"] == True
182+
assert results[1]["text"] == "example3"
183+
assert results[1]["completed"] == True
184+
185+
# fetch_one()
186+
query = "SELECT * FROM notes WHERE completed = :completed"
187+
result = await database.fetch_one(query=query, values={"completed": False})
188+
assert result["text"] == "example2"
189+
assert result["completed"] == False
190+
191+
# iterate()
192+
query = "SELECT * FROM notes"
193+
iterate_results = []
194+
async for result in database.iterate(query=query):
195+
iterate_results.append(result)
196+
assert len(iterate_results) == 3
197+
assert iterate_results[0]["text"] == "example1"
198+
assert iterate_results[0]["completed"] == True
199+
assert iterate_results[1]["text"] == "example2"
200+
assert iterate_results[1]["completed"] == False
201+
assert iterate_results[2]["text"] == "example3"
202+
assert iterate_results[2]["completed"] == True
203+
204+
154205
@pytest.mark.parametrize("database_url", DATABASE_URLS)
155206
@async_adapter
156207
async def test_results_support_mapping_interface(database_url):
@@ -234,8 +285,8 @@ async def test_execute_return_val(database_url):
234285
query = notes.insert()
235286
values = {"text": "example1", "completed": True}
236287
pk = await database.execute(query, values)
237-
238288
assert isinstance(pk, int)
289+
239290
query = notes.select().where(notes.c.id == pk)
240291
result = await database.fetch_one(query)
241292
assert result["text"] == "example1"

0 commit comments

Comments
 (0)