Skip to content

Commit a923621

Browse files
vlansejettify
authored andcommitted
Fix handling of user-defined types for sqlalchemy (#291)
* fix handling of user-defined types for sqlalchemy #290 * rename test class
1 parent 5b20a8a commit a923621

File tree

5 files changed

+126
-7
lines changed

5 files changed

+126
-7
lines changed

CHANGES.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
Changes
22
-------
33

4+
0.0.15 (2018-05-13)
5+
^^^^^^^^^^^^^^^^^^^
6+
7+
* Fixed handling of user-defined types for sqlalchemy #290
8+
9+
410
0.0.14 (2018-04-22)
511
^^^^^^^^^^^^^^^^^^^
612

aiomysql/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from .cursors import Cursor, SSCursor, DictCursor, SSDictCursor
3434
from .pool import create_pool, Pool
3535

36-
__version__ = '0.0.14'
36+
__version__ = '0.0.15'
3737

3838
__all__ = [
3939

aiomysql/sa/connection.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ async def _execute(self, query, *multiparams, **params):
7171
elif dp:
7272
dp = dp[0]
7373

74+
result_map = None
75+
7476
if isinstance(query, str):
7577
await cursor.execute(query, dp or None)
7678
elif isinstance(query, ClauseElement):
@@ -97,18 +99,23 @@ async def _execute(self, query, *multiparams, **params):
9799
processed_parameters.append(params)
98100
post_processed_params = self._dialect.execute_sequence_format(
99101
processed_parameters)
102+
result_map = compiled._result_columns
103+
100104
else:
101105
if dp:
102106
raise exc.ArgumentError("Don't mix sqlalchemy DDL clause "
103107
"and execution with parameters")
104108
post_processed_params = [compiled.construct_params()]
109+
result_map = None
105110
await cursor.execute(str(compiled), post_processed_params[0])
106111
else:
107112
raise exc.ArgumentError("sql statement should be str or "
108113
"SQLAlchemy data "
109114
"selection/modification clause")
110115

111-
ret = await create_result_proxy(self, cursor, self._dialect)
116+
ret = await create_result_proxy(
117+
self, cursor, self._dialect, result_map
118+
)
112119
self._weak_results.add(ret)
113120
return ret
114121

aiomysql/sa/result.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from . import exc
1010

1111

12-
async def create_result_proxy(connection, cursor, dialect):
13-
result_proxy = ResultProxy(connection, cursor, dialect)
12+
async def create_result_proxy(connection, cursor, dialect, result_map):
13+
result_proxy = ResultProxy(connection, cursor, dialect, result_map)
1414
await result_proxy._prepare()
1515
return result_proxy
1616

@@ -95,6 +95,12 @@ class ResultMetaData:
9595
def __init__(self, result_proxy, metadata):
9696
self._processors = processors = []
9797

98+
result_map = {}
99+
100+
if result_proxy._result_map:
101+
result_map = {elem[0]: elem[3] for elem in
102+
result_proxy._result_map}
103+
98104
# We do not strictly need to store the processor in the key mapping,
99105
# though it is faster in the Python version (probably because of the
100106
# saved attribute lookup self._processors)
@@ -124,8 +130,13 @@ def __init__(self, result_proxy, metadata):
124130
# if dialect.requires_name_normalize:
125131
# colname = dialect.normalize_name(colname)
126132

127-
name, obj, type_ = \
128-
colname, None, typemap.get(coltype, sqltypes.NULLTYPE)
133+
name, obj, type_ = (
134+
colname,
135+
None,
136+
result_map.get(
137+
colname,
138+
typemap.get(coltype, sqltypes.NULLTYPE))
139+
)
129140

130141
processor = type_._cached_result_processor(dialect, coltype)
131142

@@ -223,13 +234,14 @@ class ResultProxy:
223234
the originating SQL statement that produced this result set.
224235
"""
225236

226-
def __init__(self, connection, cursor, dialect):
237+
def __init__(self, connection, cursor, dialect, result_map):
227238
self._dialect = dialect
228239
self._closed = False
229240
self._cursor = cursor
230241
self._connection = connection
231242
self._rowcount = cursor.rowcount
232243
self._lastrowid = cursor.lastrowid
244+
self._result_map = result_map
233245

234246
async def _prepare(self):
235247
loop = self._connection.connection.loop

tests/sa/test_sa_types.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import asyncio
2+
from aiomysql import connect, sa
3+
from enum import IntEnum
4+
5+
import os
6+
import unittest
7+
from unittest import mock
8+
9+
from sqlalchemy import MetaData, Table, Column, Integer, TypeDecorator
10+
11+
12+
class UserDefinedEnum(IntEnum):
13+
Value1 = 111
14+
Value2 = 222
15+
16+
17+
class IntEnumField(TypeDecorator):
18+
impl = Integer
19+
20+
def __init__(self, enum_class, *arg, **kw):
21+
TypeDecorator.__init__(self, *arg, **kw)
22+
self.enum_class = enum_class
23+
24+
def process_bind_param(self, value, dialect):
25+
""" From python to DB """
26+
if value is None:
27+
return None
28+
elif not isinstance(value, self.enum_class):
29+
return self.enum_class(value).value
30+
else:
31+
return value.value
32+
33+
def process_result_value(self, value, dialect):
34+
""" From DB to Python """
35+
if value is None:
36+
return None
37+
38+
return self.enum_class(value)
39+
40+
41+
meta = MetaData()
42+
tbl = Table('sa_test_type_tbl', meta,
43+
Column('id', Integer, nullable=False,
44+
primary_key=True),
45+
Column('val', IntEnumField(enum_class=UserDefinedEnum)))
46+
47+
48+
class TestSATypes(unittest.TestCase):
49+
def setUp(self):
50+
self.loop = asyncio.new_event_loop()
51+
asyncio.set_event_loop(None)
52+
self.host = os.environ.get('MYSQL_HOST', 'localhost')
53+
self.port = int(os.environ.get('MYSQL_PORT', 3306))
54+
self.user = os.environ.get('MYSQL_USER', 'root')
55+
self.db = os.environ.get('MYSQL_DB', 'test_pymysql')
56+
self.password = os.environ.get('MYSQL_PASSWORD', '')
57+
58+
def tearDown(self):
59+
self.loop.close()
60+
61+
async def connect(self, **kwargs):
62+
conn = await connect(db=self.db,
63+
user=self.user,
64+
password=self.password,
65+
host=self.host,
66+
loop=self.loop,
67+
port=self.port,
68+
**kwargs)
69+
await conn.autocommit(True)
70+
cur = await conn.cursor()
71+
await cur.execute("DROP TABLE IF EXISTS sa_test_type_tbl")
72+
await cur.execute("CREATE TABLE sa_test_type_tbl "
73+
"(id serial, val bigint)")
74+
await cur._connection.commit()
75+
engine = mock.Mock()
76+
engine.dialect = sa.engine._dialect
77+
return sa.SAConnection(conn, engine)
78+
79+
def test_values(self):
80+
async def go():
81+
conn = await self.connect()
82+
83+
await conn.execute(tbl.insert().values(
84+
val=UserDefinedEnum.Value1)
85+
)
86+
result = await conn.execute(tbl.select().where(
87+
tbl.c.val == UserDefinedEnum.Value1)
88+
)
89+
data = await result.fetchone()
90+
self.assertEqual(
91+
data['val'], UserDefinedEnum.Value1
92+
)
93+
94+
self.loop.run_until_complete(go())

0 commit comments

Comments
 (0)