Skip to content

Commit e4762d9

Browse files
Mikkgnmikkegn
andauthored
added support for sqlalchemy default parameters #455 (#456)
Co-authored-by: mikkegn <[email protected]>
1 parent abef958 commit e4762d9

File tree

2 files changed

+123
-0
lines changed

2 files changed

+123
-0
lines changed

aiomysql/sa/engine.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,29 @@
1212

1313
try:
1414
from sqlalchemy.dialects.mysql.pymysql import MySQLDialect_pymysql
15+
from sqlalchemy.dialects.mysql.mysqldb import MySQLCompiler_mysqldb
1516
except ImportError: # pragma: no cover
1617
raise ImportError('aiomysql.sa requires sqlalchemy')
1718

1819

20+
class MySQLCompiler_pymysql(MySQLCompiler_mysqldb):
21+
def construct_params(self, params=None, _group_number=None, _check=True):
22+
pd = super().construct_params(params, _group_number, _check)
23+
24+
for column in self.prefetch:
25+
pd[column.key] = self._exec_default(column.default)
26+
27+
return pd
28+
29+
def _exec_default(self, default):
30+
if default.is_callable:
31+
return default.arg(self.dialect)
32+
else:
33+
return default.arg
34+
35+
1936
_dialect = MySQLDialect_pymysql(paramstyle='pyformat')
37+
_dialect.statement_compiler = MySQLCompiler_pymysql
2038
_dialect.default_paramstyle = 'pyformat'
2139

2240

tests/sa/test_sa_default.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import datetime
2+
3+
import pytest
4+
from sqlalchemy import MetaData, Table, Column, Integer, String
5+
from sqlalchemy import func, DateTime, Boolean
6+
7+
from aiomysql import sa
8+
9+
meta = MetaData()
10+
table = Table('sa_tbl_default_test', meta,
11+
Column('id', Integer, nullable=False, primary_key=True),
12+
Column('string_length', Integer,
13+
default=func.length('qwerty')),
14+
Column('number', Integer, default=100, nullable=False),
15+
Column('description', String(255), nullable=False,
16+
default='default test'),
17+
Column('created_at', DateTime,
18+
default=datetime.datetime.now),
19+
Column('enabled', Boolean, default=True))
20+
21+
22+
@pytest.fixture()
23+
def make_engine(mysql_params, connection):
24+
async def _make_engine(**kwargs):
25+
return (await sa.create_engine(db=mysql_params['db'],
26+
user=mysql_params['user'],
27+
password=mysql_params['password'],
28+
host=mysql_params['host'],
29+
port=mysql_params['port'],
30+
minsize=10,
31+
**kwargs))
32+
33+
return _make_engine
34+
35+
36+
async def start(engine):
37+
async with engine.acquire() as conn:
38+
await conn.execute("DROP TABLE IF EXISTS sa_tbl_default_test")
39+
await conn.execute("CREATE TABLE sa_tbl_default_test "
40+
"(id integer,"
41+
" string_length integer, "
42+
"number integer,"
43+
" description VARCHAR(255), "
44+
"created_at DATETIME(6), "
45+
"enabled TINYINT)")
46+
47+
48+
@pytest.mark.run_loop
49+
async def test_default_fields(make_engine):
50+
engine = await make_engine()
51+
await start(engine)
52+
async with engine.acquire() as conn:
53+
await conn.execute(table.insert().values())
54+
res = await conn.execute(table.select())
55+
row = await res.fetchone()
56+
assert row.string_length == 6
57+
assert row.number == 100
58+
assert row.description == 'default test'
59+
assert row.enabled is True
60+
assert type(row.created_at) == datetime.datetime
61+
62+
63+
@pytest.mark.run_loop
64+
async def test_default_fields_isnull(make_engine):
65+
engine = await make_engine()
66+
await start(engine)
67+
async with engine.acquire() as conn:
68+
created_at = None
69+
enabled = False
70+
await conn.execute(table.insert().values(
71+
enabled=enabled,
72+
created_at=created_at,
73+
))
74+
75+
res = await conn.execute(table.select())
76+
row = await res.fetchone()
77+
assert row.number == 100
78+
assert row.string_length == 6
79+
assert row.description == 'default test'
80+
assert row.enabled == enabled
81+
assert row.created_at == created_at
82+
83+
84+
async def test_default_fields_edit(make_engine):
85+
engine = await make_engine()
86+
await start(engine)
87+
async with engine.acquire() as conn:
88+
created_at = datetime.datetime.now()
89+
description = 'new descr'
90+
enabled = False
91+
number = 111
92+
await conn.execute(table.insert().values(
93+
description=description,
94+
enabled=enabled,
95+
created_at=created_at,
96+
number=number,
97+
))
98+
99+
res = await conn.execute(table.select())
100+
row = await res.fetchone()
101+
assert row.number == number
102+
assert row.string_length == 6
103+
assert row.description == description
104+
assert row.enabled == enabled
105+
assert row.created_at == created_at

0 commit comments

Comments
 (0)