|
9 | 9 |
|
10 | 10 | from sqlalchemy import text
|
11 | 11 | from sqlalchemy.sql import ClauseElement
|
| 12 | +from sqlalchemy.sql.dml import ValuesBase |
| 13 | +from sqlalchemy.sql.expression import type_coerce |
| 14 | + |
12 | 15 |
|
13 | 16 | from databases.importer import import_from_string
|
14 | 17 | from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend
|
@@ -294,11 +297,51 @@ def _build_query(
|
294 | 297 | query = text(query)
|
295 | 298 |
|
296 | 299 | return query.bindparams(**values) if values is not None else query
|
297 |
| - elif values: |
| 300 | + |
| 301 | + # 2 paths where we apply column defaults: |
| 302 | + # - values are supplied (the object must be a ValuesBase) |
| 303 | + # - values is None but the object is a ValuesBase |
| 304 | + if values is not None and not isinstance(query, ValuesBase): |
| 305 | + raise TypeError("values supplied but query doesn't support .values()") |
| 306 | + |
| 307 | + if values is not None or isinstance(query, ValuesBase): |
| 308 | + values = Connection._apply_column_defaults(query, values) |
298 | 309 | return query.values(**values)
|
299 | 310 |
|
300 | 311 | return query
|
301 | 312 |
|
| 313 | + @staticmethod |
| 314 | + def _apply_column_defaults(query: ValuesBase, values: dict = None) -> dict: |
| 315 | + """Add default values from the table of a query.""" |
| 316 | + new_values = {} |
| 317 | + values = values or {} |
| 318 | + |
| 319 | + for column in query.table.c: |
| 320 | + if column.name in values: |
| 321 | + continue |
| 322 | + |
| 323 | + if column.default: |
| 324 | + default = column.default |
| 325 | + |
| 326 | + if default.is_sequence: # pragma: no cover |
| 327 | + # TODO: support sequences |
| 328 | + continue |
| 329 | + elif default.is_callable: |
| 330 | + value = default.arg(FakeExecutionContext()) |
| 331 | + elif default.is_clause_element: # pragma: no cover |
| 332 | + # TODO: implement clause element |
| 333 | + # For this, the _build_query method needs to |
| 334 | + # become an instance method so that it can access |
| 335 | + # self._connection. |
| 336 | + continue |
| 337 | + else: |
| 338 | + value = default.arg |
| 339 | + |
| 340 | + new_values[column.name] = value |
| 341 | + |
| 342 | + new_values.update(values) |
| 343 | + return new_values |
| 344 | + |
302 | 345 |
|
303 | 346 | class Transaction:
|
304 | 347 | def __init__(
|
@@ -489,3 +532,20 @@ def __repr__(self) -> str:
|
489 | 532 |
|
490 | 533 | def __eq__(self, other: typing.Any) -> bool:
|
491 | 534 | return str(self) == str(other)
|
| 535 | + |
| 536 | + |
| 537 | +class FakeExecutionContext: |
| 538 | + """ |
| 539 | + This is an object that raises an error when one of its properties are |
| 540 | + attempted to be accessed. Because we're not _really_ using SQLAlchemy |
| 541 | + (besides using its query builder), we can't pass a real ExecutionContext |
| 542 | + to ColumnDefault objects. This class makes it so that any attempts to |
| 543 | + access the execution context argument by a column default callable |
| 544 | + blows up loudly and clearly. |
| 545 | + """ |
| 546 | + |
| 547 | + def __getattr__(self, _: str) -> typing.NoReturn: # pragma: no cover |
| 548 | + raise NotImplementedError( |
| 549 | + "Databases does not have a real SQLAlchemy ExecutionContext " |
| 550 | + "implementation." |
| 551 | + ) |
0 commit comments