Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 4088e42

Browse files
Yureienvmarkovtsev
andauthored
Fixed SQLAlchemy DDL statements (#226)
* Fixed SQLAlchemy DDL statements * Using black for code formatting * Moved implementation to higher-level and created tests Co-authored-by: Vadim Markovtsev <[email protected]>
1 parent 868132b commit 4088e42

File tree

5 files changed

+85
-41
lines changed

5 files changed

+85
-41
lines changed

databases/backends/aiopg.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
1111
from sqlalchemy.engine.result import ResultMetaData, RowProxy
1212
from sqlalchemy.sql import ClauseElement
13+
from sqlalchemy.sql.ddl import DDLElement
1314
from sqlalchemy.types import TypeEngine
1415

1516
from databases.core import DatabaseURL
@@ -181,18 +182,23 @@ def _compile(
181182
self, query: ClauseElement
182183
) -> typing.Tuple[str, dict, CompilationContext]:
183184
compiled = query.compile(dialect=self._dialect)
184-
args = compiled.construct_params()
185-
for key, val in args.items():
186-
if key in compiled._bind_processors:
187-
args[key] = compiled._bind_processors[key](val)
188185

189186
execution_context = self._dialect.execution_ctx_cls()
190187
execution_context.dialect = self._dialect
191-
execution_context.result_column_struct = (
192-
compiled._result_columns,
193-
compiled._ordered_columns,
194-
compiled._textual_ordered_columns,
195-
)
188+
189+
if not isinstance(query, DDLElement):
190+
args = compiled.construct_params()
191+
for key, val in args.items():
192+
if key in compiled._bind_processors:
193+
args[key] = compiled._bind_processors[key](val)
194+
195+
execution_context.result_column_struct = (
196+
compiled._result_columns,
197+
compiled._ordered_columns,
198+
compiled._textual_ordered_columns,
199+
)
200+
else:
201+
args = {}
196202

197203
logger.debug("Query: %s\nArgs: %s", compiled.string, args)
198204
return compiled.string, args, CompilationContext(execution_context)

databases/backends/mysql.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
99
from sqlalchemy.engine.result import ResultMetaData, RowProxy
1010
from sqlalchemy.sql import ClauseElement
11+
from sqlalchemy.sql.ddl import DDLElement
1112
from sqlalchemy.types import TypeEngine
1213

1314
from databases.core import LOG_EXTRA, DatabaseURL
@@ -171,18 +172,23 @@ def _compile(
171172
self, query: ClauseElement
172173
) -> typing.Tuple[str, dict, CompilationContext]:
173174
compiled = query.compile(dialect=self._dialect)
174-
args = compiled.construct_params()
175-
for key, val in args.items():
176-
if key in compiled._bind_processors:
177-
args[key] = compiled._bind_processors[key](val)
178175

179176
execution_context = self._dialect.execution_ctx_cls()
180177
execution_context.dialect = self._dialect
181-
execution_context.result_column_struct = (
182-
compiled._result_columns,
183-
compiled._ordered_columns,
184-
compiled._textual_ordered_columns,
185-
)
178+
179+
if not isinstance(query, DDLElement):
180+
args = compiled.construct_params()
181+
for key, val in args.items():
182+
if key in compiled._bind_processors:
183+
args[key] = compiled._bind_processors[key](val)
184+
185+
execution_context.result_column_struct = (
186+
compiled._result_columns,
187+
compiled._ordered_columns,
188+
compiled._textual_ordered_columns,
189+
)
190+
else:
191+
args = {}
186192

187193
query_message = compiled.string.replace(" \n", " ").replace("\n", " ")
188194
logger.debug("Query: %s Args: %s", query_message, repr(args), extra=LOG_EXTRA)

databases/backends/postgres.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from sqlalchemy.dialects.postgresql import pypostgresql
77
from sqlalchemy.engine.interfaces import Dialect
88
from sqlalchemy.sql import ClauseElement
9+
from sqlalchemy.sql.ddl import DDLElement
910
from sqlalchemy.sql.schema import Column
1011
from sqlalchemy.types import TypeEngine
1112

@@ -209,24 +210,32 @@ def transaction(self) -> TransactionBackend:
209210

210211
def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
211212
compiled = query.compile(dialect=self._dialect)
212-
compiled_params = sorted(compiled.params.items())
213213

214-
mapping = {
215-
key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)
216-
}
217-
compiled_query = compiled.string % mapping
214+
if not isinstance(query, DDLElement):
215+
compiled_params = sorted(compiled.params.items())
218216

219-
processors = compiled._bind_processors
220-
args = [
221-
processors[key](val) if key in processors else val
222-
for key, val in compiled_params
223-
]
217+
mapping = {
218+
key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)
219+
}
220+
compiled_query = compiled.string % mapping
221+
222+
processors = compiled._bind_processors
223+
args = [
224+
processors[key](val) if key in processors else val
225+
for key, val in compiled_params
226+
]
227+
228+
result_map = compiled._result_columns
229+
else:
230+
compiled_query = compiled.string
231+
args = []
232+
result_map = None
224233

225234
query_message = compiled_query.replace(" \n", " ").replace("\n", " ")
226235
logger.debug(
227236
"Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA
228237
)
229-
return compiled_query, args, compiled._result_columns
238+
return compiled_query, args, result_map
230239

231240
@staticmethod
232241
def _create_column_maps(

databases/backends/sqlite.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
88
from sqlalchemy.engine.result import ResultMetaData, RowProxy
99
from sqlalchemy.sql import ClauseElement
10+
from sqlalchemy.sql.ddl import DDLElement
1011
from sqlalchemy.types import TypeEngine
1112

1213
from databases.core import LOG_EXTRA, DatabaseURL
@@ -139,21 +140,25 @@ def _compile(
139140
self, query: ClauseElement
140141
) -> typing.Tuple[str, list, CompilationContext]:
141142
compiled = query.compile(dialect=self._dialect)
142-
args = []
143-
for key, raw_val in compiled.construct_params().items():
144-
if key in compiled._bind_processors:
145-
val = compiled._bind_processors[key](raw_val)
146-
else:
147-
val = raw_val
148-
args.append(val)
149143

150144
execution_context = self._dialect.execution_ctx_cls()
151145
execution_context.dialect = self._dialect
152-
execution_context.result_column_struct = (
153-
compiled._result_columns,
154-
compiled._ordered_columns,
155-
compiled._textual_ordered_columns,
156-
)
146+
147+
args = []
148+
149+
if not isinstance(query, DDLElement):
150+
for key, raw_val in compiled.construct_params().items():
151+
if key in compiled._bind_processors:
152+
val = compiled._bind_processors[key](raw_val)
153+
else:
154+
val = raw_val
155+
args.append(val)
156+
157+
execution_context.result_column_struct = (
158+
compiled._result_columns,
159+
compiled._ordered_columns,
160+
compiled._textual_ordered_columns,
161+
)
157162

158163
query_message = compiled.string.replace(" \n", " ").replace("\n", " ")
159164
logger.debug(

tests/test_databases.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,24 @@ async def test_queries_raw(database_url):
236236
assert iterate_results[2]["completed"] == True
237237

238238

239+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
240+
@async_adapter
241+
async def test_ddl_queries(database_url):
242+
"""
243+
Test that the built-in DDL elements such as `DropTable()`,
244+
`CreateTable()` are supported (using SQLAlchemy core).
245+
"""
246+
async with Database(database_url) as database:
247+
async with database.transaction(force_rollback=True):
248+
# DropTable()
249+
query = sqlalchemy.schema.DropTable(notes)
250+
await database.execute(query)
251+
252+
# CreateTable()
253+
query = sqlalchemy.schema.CreateTable(notes)
254+
await database.execute(query)
255+
256+
239257
@pytest.mark.parametrize("database_url", DATABASE_URLS)
240258
@async_adapter
241259
async def test_results_support_mapping_interface(database_url):

0 commit comments

Comments
 (0)