Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 0fc16b2

Browse files
authored
Fix some column types being parsed twice (#582)
* Fix JSON and enum type columns * Add time to reparsing check, add date and time tests * Make processed types inclusive rather than exclusive, limit to just DIALECT_EXCLUDE
1 parent 1e40ad1 commit 0fc16b2

File tree

2 files changed

+143
-8
lines changed

2 files changed

+143
-8
lines changed

databases/backends/common/records.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
import json
1+
import enum
22
import typing
3-
from datetime import date, datetime
3+
from datetime import date, datetime, time
44

55
from sqlalchemy.engine.interfaces import Dialect
66
from sqlalchemy.engine.row import Row as SQLRow
77
from sqlalchemy.sql.compiler import _CompileLabel
88
from sqlalchemy.sql.schema import Column
9+
from sqlalchemy.sql.sqltypes import JSON
910
from sqlalchemy.types import TypeEngine
1011

1112
from databases.interfaces import Record as RecordInterface
@@ -62,12 +63,10 @@ def __getitem__(self, key: typing.Any) -> typing.Any:
6263
raw = self._row[idx]
6364
processor = datatype._cached_result_processor(self._dialect, None)
6465

65-
if self._dialect.name not in DIALECT_EXCLUDE:
66-
if isinstance(raw, dict):
67-
raw = json.dumps(raw)
66+
if self._dialect.name in DIALECT_EXCLUDE:
67+
if processor is not None and isinstance(raw, (int, str, float)):
68+
return processor(raw)
6869

69-
if processor is not None and (not isinstance(raw, (datetime, date))):
70-
return processor(raw)
7170
return raw
7271

7372
def __iter__(self) -> typing.Iterator:

tests/test_databases.py

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import datetime
33
import decimal
4+
import enum
45
import functools
56
import gc
67
import itertools
@@ -55,6 +56,47 @@ def process_result_value(self, value, dialect):
5556
sqlalchemy.Column("published", sqlalchemy.DateTime),
5657
)
5758

59+
# Used to test Date
60+
events = sqlalchemy.Table(
61+
"events",
62+
metadata,
63+
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
64+
sqlalchemy.Column("date", sqlalchemy.Date),
65+
)
66+
67+
68+
# Used to test Time
69+
daily_schedule = sqlalchemy.Table(
70+
"daily_schedule",
71+
metadata,
72+
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
73+
sqlalchemy.Column("time", sqlalchemy.Time),
74+
)
75+
76+
77+
class TshirtSize(enum.Enum):
78+
SMALL = "SMALL"
79+
MEDIUM = "MEDIUM"
80+
LARGE = "LARGE"
81+
XL = "XL"
82+
83+
84+
class TshirtColor(enum.Enum):
85+
BLUE = 0
86+
GREEN = 1
87+
YELLOW = 2
88+
RED = 3
89+
90+
91+
# Used to test Enum
92+
tshirt_size = sqlalchemy.Table(
93+
"tshirt_size",
94+
metadata,
95+
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
96+
sqlalchemy.Column("size", sqlalchemy.Enum(TshirtSize)),
97+
sqlalchemy.Column("color", sqlalchemy.Enum(TshirtColor)),
98+
)
99+
58100
# Used to test JSON
59101
session = sqlalchemy.Table(
60102
"session",
@@ -928,6 +970,52 @@ async def test_datetime_field(database_url):
928970
assert results[0]["published"] == now
929971

930972

973+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
974+
@async_adapter
975+
async def test_date_field(database_url):
976+
"""
977+
Test Date columns, to ensure records are coerced to/from proper Python types.
978+
"""
979+
980+
async with Database(database_url) as database:
981+
async with database.transaction(force_rollback=True):
982+
now = datetime.date.today()
983+
984+
# execute()
985+
query = events.insert()
986+
values = {"date": now}
987+
await database.execute(query, values)
988+
989+
# fetch_all()
990+
query = events.select()
991+
results = await database.fetch_all(query=query)
992+
assert len(results) == 1
993+
assert results[0]["date"] == now
994+
995+
996+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
997+
@async_adapter
998+
async def test_time_field(database_url):
999+
"""
1000+
Test Time columns, to ensure records are coerced to/from proper Python types.
1001+
"""
1002+
1003+
async with Database(database_url) as database:
1004+
async with database.transaction(force_rollback=True):
1005+
now = datetime.datetime.now().time().replace(microsecond=0)
1006+
1007+
# execute()
1008+
query = daily_schedule.insert()
1009+
values = {"time": now}
1010+
await database.execute(query, values)
1011+
1012+
# fetch_all()
1013+
query = daily_schedule.select()
1014+
results = await database.fetch_all(query=query)
1015+
assert len(results) == 1
1016+
assert results[0]["time"] == now
1017+
1018+
9311019
@pytest.mark.parametrize("database_url", DATABASE_URLS)
9321020
@async_adapter
9331021
async def test_decimal_field(database_url):
@@ -957,7 +1045,32 @@ async def test_decimal_field(database_url):
9571045

9581046
@pytest.mark.parametrize("database_url", DATABASE_URLS)
9591047
@async_adapter
960-
async def test_json_field(database_url):
1048+
async def test_enum_field(database_url):
1049+
"""
1050+
Test enum columns, to ensure correct cross-database support.
1051+
"""
1052+
1053+
async with Database(database_url) as database:
1054+
async with database.transaction(force_rollback=True):
1055+
# execute()
1056+
size = TshirtSize.SMALL
1057+
color = TshirtColor.GREEN
1058+
values = {"size": size, "color": color}
1059+
query = tshirt_size.insert()
1060+
await database.execute(query, values)
1061+
1062+
# fetch_all()
1063+
query = tshirt_size.select()
1064+
results = await database.fetch_all(query=query)
1065+
1066+
assert len(results) == 1
1067+
assert results[0]["size"] == size
1068+
assert results[0]["color"] == color
1069+
1070+
1071+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
1072+
@async_adapter
1073+
async def test_json_dict_field(database_url):
9611074
"""
9621075
Test JSON columns, to ensure correct cross-database support.
9631076
"""
@@ -978,6 +1091,29 @@ async def test_json_field(database_url):
9781091
assert results[0]["data"] == {"text": "hello", "boolean": True, "int": 1}
9791092

9801093

1094+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
1095+
@async_adapter
1096+
async def test_json_list_field(database_url):
1097+
"""
1098+
Test JSON columns, to ensure correct cross-database support.
1099+
"""
1100+
1101+
async with Database(database_url) as database:
1102+
async with database.transaction(force_rollback=True):
1103+
# execute()
1104+
data = ["lemon", "raspberry", "lime", "pumice"]
1105+
values = {"data": data}
1106+
query = session.insert()
1107+
await database.execute(query, values)
1108+
1109+
# fetch_all()
1110+
query = session.select()
1111+
results = await database.fetch_all(query=query)
1112+
1113+
assert len(results) == 1
1114+
assert results[0]["data"] == ["lemon", "raspberry", "lime", "pumice"]
1115+
1116+
9811117
@pytest.mark.parametrize("database_url", DATABASE_URLS)
9821118
@async_adapter
9831119
async def test_custom_field(database_url):

0 commit comments

Comments
 (0)