Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit dcf3378

Browse files
Implement column defaults for INSERT/UPDATE
1 parent 79e491c commit dcf3378

File tree

2 files changed

+101
-1
lines changed

2 files changed

+101
-1
lines changed

databases/core.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99

1010
from sqlalchemy import text
1111
from sqlalchemy.sql import ClauseElement
12+
from sqlalchemy.sql.dml import ValuesBase
13+
from sqlalchemy.sql.expression import type_coerce
14+
1215

1316
from databases.importer import import_from_string
1417
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend
@@ -294,11 +297,51 @@ def _build_query(
294297
query = text(query)
295298

296299
return query.bindparams(**values) if values is not None else query
297-
elif values:
300+
301+
# 2 paths where we apply column defaults:
302+
# - values are supplied (the object must be a ValuesBase)
303+
# - values is None but the object is a ValuesBase
304+
if values is not None and not isinstance(query, ValuesBase):
305+
raise TypeError("values supplied but query doesn't support .values()")
306+
307+
if values is not None or isinstance(query, ValuesBase):
308+
values = Connection._apply_column_defaults(query, values)
298309
return query.values(**values)
299310

300311
return query
301312

313+
@staticmethod
314+
def _apply_column_defaults(query: ValuesBase, values: dict = None) -> dict:
315+
"""Add default values from the table of a query."""
316+
new_values = {}
317+
values = values or {}
318+
319+
for column in query.table.c:
320+
if column.name in values:
321+
continue
322+
323+
if column.default:
324+
default = column.default
325+
326+
if default.is_sequence: # pragma: no cover
327+
# TODO: support sequences
328+
continue
329+
elif default.is_callable:
330+
value = default.arg(FakeExecutionContext())
331+
elif default.is_clause_element: # pragma: no cover
332+
# TODO: implement clause element
333+
# For this, the _build_query method needs to
334+
# become an instance method so that it can access
335+
# self._connection.
336+
continue
337+
else:
338+
value = default.arg
339+
340+
new_values[column.name] = value
341+
342+
new_values.update(values)
343+
return new_values
344+
302345

303346
class Transaction:
304347
def __init__(
@@ -489,3 +532,20 @@ def __repr__(self) -> str:
489532

490533
def __eq__(self, other: typing.Any) -> bool:
491534
return str(self) == str(other)
535+
536+
537+
class FakeExecutionContext:
538+
"""
539+
This is an object that raises an error when one of its properties are
540+
attempted to be accessed. Because we're not _really_ using SQLAlchemy
541+
(besides using its query builder), we can't pass a real ExecutionContext
542+
to ColumnDefault objects. This class makes it so that any attempts to
543+
access the execution context argument by a column default callable
544+
blows up loudly and clearly.
545+
"""
546+
547+
def __getattr__(self, _: str) -> typing.NoReturn: # pragma: no cover
548+
raise NotImplementedError(
549+
"Databases does not have a real SQLAlchemy ExecutionContext "
550+
"implementation."
551+
)

tests/test_databases.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,17 @@ def process_result_value(self, value, dialect):
7070
sqlalchemy.Column("price", sqlalchemy.Numeric(precision=30, scale=20)),
7171
)
7272

73+
# Used to test column default values
74+
timestamps = sqlalchemy.Table(
75+
"timestamps",
76+
metadata,
77+
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
78+
sqlalchemy.Column(
79+
"timestamp", sqlalchemy.DateTime, default=datetime.datetime.now, nullable=False
80+
),
81+
sqlalchemy.Column("priority", sqlalchemy.Integer, default=0, nullable=False),
82+
)
83+
7384

7485
@pytest.fixture(autouse=True, scope="module")
7586
def create_test_database():
@@ -925,3 +936,32 @@ async def test_column_names(database_url, select_query):
925936
assert sorted(results[0].keys()) == ["completed", "id", "text"]
926937
assert results[0]["text"] == "example1"
927938
assert results[0]["completed"] == True
939+
940+
941+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
942+
@async_adapter
943+
async def test_column_defaults(database_url):
944+
"""
945+
Test correct usage of column defaults.
946+
"""
947+
async with Database(database_url) as database:
948+
async with database.transaction(force_rollback=True):
949+
# with just defaults
950+
query = timestamps.insert()
951+
await database.execute(query)
952+
results = await database.fetch_all(query=timestamps.select())
953+
assert len(results) == 1
954+
await database.execute(timestamps.delete())
955+
956+
# with default value overridden
957+
dt = datetime.datetime.now() - datetime.timedelta(seconds=10)
958+
values = {"timestamp": dt}
959+
await database.execute(query, values)
960+
results = await database.fetch_all(timestamps.select())
961+
assert len(results) == 1
962+
assert results[0]["timestamp"] == dt
963+
964+
# testing invalid passing of values with non-ValuesBase
965+
# argument
966+
with pytest.raises(TypeError, match=r".*support \.values\(\).*"):
967+
await database.execute(timestamps.select(), {})

0 commit comments

Comments
 (0)