Skip to content

Commit f712904

Browse files
scp10011terrycain
authored andcommitted
Deserialize json type data and Inherit multiple types of cursors at the same time (#343)
* Json Deserialization Cursor * Inherit multiple types of cursors at the same time * SS Cursor Modify the result by _conv_row * Mysql mariadb type is inconsistent JSON is only an alias for LONGBLOB for compatibility reasons with MySQL #https://mariadb.com/kb/en/library/json-data-type * Test against mysql
1 parent 9e48e85 commit f712904

File tree

5 files changed

+252
-8
lines changed

5 files changed

+252
-8
lines changed

.travis.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ matrix:
2626
env: PYTHONASYNCIODEBUG=
2727
addons:
2828
mariadb: 10.1
29+
- python: 3.6
30+
env: PYTHONASYNCIODEBUG=
31+
addons:
32+
mysql: 5.7
2933

3034

3135
before_script:

aiomysql/connection.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,11 @@
3737
from pymysql.connections import LoadLocalPacketWrapper
3838
from pymysql.connections import lenenc_int
3939

40-
4140
# from aiomysql.utils import _convert_to_str
4241
from .cursors import Cursor
4342
from .utils import _ConnectionContextManager, _ContextManager
4443
from .log import logger
4544

46-
4745
DEFAULT_USER = getpass.getuser()
4846

4947

@@ -389,7 +387,7 @@ def escape_string(self, s):
389387
return s.replace("'", "''")
390388
return escape_string(s)
391389

392-
def cursor(self, cursor=None):
390+
def cursor(self, *cursors):
393391
"""Instantiates and returns a cursor
394392
395393
By default, :class:`Cursor` is returned. It is possible to also give a
@@ -402,11 +400,19 @@ def cursor(self, cursor=None):
402400
"""
403401
self._ensure_alive()
404402
self._last_usage = self._loop.time()
405-
if cursor is not None and not issubclass(cursor, Cursor):
403+
try:
404+
if cursors and \
405+
any(not issubclass(cursor, Cursor) for cursor in cursors):
406+
raise TypeError('Custom cursor must be subclass of Cursor')
407+
except TypeError:
406408
raise TypeError('Custom cursor must be subclass of Cursor')
407-
408-
if cursor:
409-
cur = cursor(self, self._echo)
409+
if cursors and len(cursors) == 1:
410+
cur = cursors[0](self, self._echo)
411+
elif cursors:
412+
cursor_name = ''.join(map(lambda x: x.__name__, cursors)) \
413+
.replace('Cursor', '') + 'Cursor'
414+
cursor_class = type(cursor_name, cursors, {})
415+
cur = cursor_class(self, self._echo)
410416
else:
411417
cur = self.cursorclass(self, self._echo)
412418
fut = self._loop.create_future()
@@ -1057,6 +1063,7 @@ def __del__(self):
10571063
warnings.warn("Unclosed connection {!r}".format(self),
10581064
ResourceWarning)
10591065
self.close()
1066+
10601067
Warning = Warning
10611068
Error = Error
10621069
InterfaceError = InterfaceError

aiomysql/cursors.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import re
2+
import json
23
import warnings
4+
import contextlib
35

46
from pymysql.err import (
57
Warning, Error, InterfaceError, DataError,
68
DatabaseError, OperationalError, IntegrityError, InternalError,
79
NotSupportedError, ProgrammingError)
810

911
from .log import logger
10-
12+
from .connection import FIELD_TYPE
1113

1214
# https://github.com/PyMySQL/PyMySQL/blob/master/pymysql/cursors.py#L11-L18
1315

@@ -515,6 +517,41 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
515517
return
516518

517519

520+
class _DeserializationCursorMixin:
521+
async def _do_get_result(self):
522+
await super()._do_get_result()
523+
if self._rows:
524+
self._rows = [self._deserialization_row(r) for r in self._rows]
525+
526+
def _deserialization_row(self, row):
527+
if row is None:
528+
return None
529+
if isinstance(row, dict):
530+
dict_flag = True
531+
else:
532+
row = list(row)
533+
dict_flag = False
534+
for index, (name, field_type, *n) in enumerate(self._description):
535+
if field_type == FIELD_TYPE.JSON:
536+
point = name if dict_flag else index
537+
with contextlib.suppress(ValueError, TypeError):
538+
row[point] = json.loads(row[point])
539+
if dict_flag:
540+
return row
541+
else:
542+
return tuple(row)
543+
544+
def _conv_row(self, row):
545+
if row is None:
546+
return None
547+
row = super()._conv_row(row)
548+
return self._deserialization_row(row)
549+
550+
551+
class DeserializationCursor(_DeserializationCursorMixin, Cursor):
552+
"""A cursor automatic deserialization of json type fields"""
553+
554+
518555
class _DictCursorMixin:
519556
# You can override this to use OrderedDict or other dict-like types.
520557
dict_type = dict
@@ -536,6 +573,7 @@ async def _do_get_result(self):
536573
def _conv_row(self, row):
537574
if row is None:
538575
return None
576+
row = super()._conv_row(row)
539577
return self.dict_type(zip(self._fields, row))
540578

541579

examples/example_cursors.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import asyncio
2+
import aiomysql
3+
4+
5+
async def test_example(loop):
6+
conn = await aiomysql.connect(host='127.0.0.1', port=3306,
7+
user='root', password='', db='mysql',
8+
loop=loop)
9+
sql = "SELECT 1 `id`, JSON_OBJECT('key1', 1, 'key2', 'abc') obj"
10+
async with conn.cursor(aiomysql.cursors.DeserializationCursor,
11+
aiomysql.cursors.DictCursor) as cur:
12+
await cur.execute(sql)
13+
print(cur.description)
14+
r = await cur.fetchall()
15+
print(r)
16+
conn.close()
17+
18+
19+
loop = asyncio.get_event_loop()
20+
loop.run_until_complete(test_example(loop))

tests/test_deserialize_cursor.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import copy
2+
import asyncio
3+
4+
import aiomysql.cursors
5+
from tests import base
6+
from tests._testutils import run_until_complete
7+
8+
9+
class TestDeserializeCursor(base.AIOPyMySQLTestCase):
10+
bob = ("bob", 21, {"k1": "pretty", "k2": [18, 25]})
11+
jim = ("jim", 56, {"k1": "rich", "k2": [20, 60]})
12+
fred = ("fred", 100, {"k1": "longevity", "k2": [100, 160]})
13+
havejson = True
14+
15+
cursor_type = aiomysql.cursors.DeserializationCursor
16+
17+
def setUp(self):
18+
super(TestDeserializeCursor, self).setUp()
19+
self.conn = conn = self.connections[0]
20+
21+
@asyncio.coroutine
22+
def prepare():
23+
c = yield from conn.cursor(self.cursor_type)
24+
25+
# create a table ane some data to query
26+
yield from c.execute("drop table if exists deserialize_cursor")
27+
yield from c.execute("select VERSION()")
28+
v = yield from c.fetchone()
29+
version, *db_type = v[0].split('-', 1)
30+
version = float(".".join(version.split('.', 2)[:2]))
31+
ismariadb = db_type and 'mariadb' in db_type[0].lower()
32+
if ismariadb or version < 5.7:
33+
yield from c.execute(
34+
"""CREATE TABLE deserialize_cursor
35+
(name char(20), age int , claim text)""")
36+
self.havejson = False
37+
else:
38+
yield from c.execute(
39+
"""CREATE TABLE deserialize_cursor
40+
(name char(20), age int , claim json)""")
41+
data = [("bob", 21, '{"k1": "pretty", "k2": [18, 25]}'),
42+
("jim", 56, '{"k1": "rich", "k2": [20, 60]}'),
43+
("fred", 100, '{"k1": "longevity", "k2": [100, 160]}')]
44+
yield from c.executemany("insert into deserialize_cursor values "
45+
"(%s,%s,%s)",
46+
data)
47+
48+
self.loop.run_until_complete(prepare())
49+
50+
def tearDown(self):
51+
@asyncio.coroutine
52+
def shutdown():
53+
c = yield from self.conn.cursor()
54+
yield from c.execute("drop table deserialize_cursor;")
55+
56+
self.loop.run_until_complete(shutdown())
57+
super(TestDeserializeCursor, self).tearDown()
58+
59+
@run_until_complete
60+
def test_deserialize_cursor(self):
61+
if not self.havejson:
62+
return
63+
bob, jim, fred = copy.deepcopy(self.bob), copy.deepcopy(
64+
self.jim), copy.deepcopy(self.fred)
65+
# all assert test compare to the structure as would come
66+
# out from MySQLdb
67+
conn = self.conn
68+
c = yield from conn.cursor(self.cursor_type)
69+
70+
# pull back the single row dict for bob and check
71+
yield from c.execute("SELECT * from deserialize_cursor "
72+
"where name='bob'")
73+
r = yield from c.fetchone()
74+
self.assertEqual(bob, r, "fetchone via DeserializeCursor failed")
75+
# same again, but via fetchall => tuple)
76+
yield from c.execute("SELECT * from deserialize_cursor "
77+
"where name='bob'")
78+
r = yield from c.fetchall()
79+
self.assertEqual([bob], r,
80+
"fetch a 1 row result via fetchall failed via "
81+
"DeserializeCursor")
82+
# get all 3 row via fetchall
83+
yield from c.execute("SELECT * from deserialize_cursor")
84+
r = yield from c.fetchall()
85+
self.assertEqual([bob, jim, fred], r,
86+
"fetchall failed via DictCursor")
87+
88+
# get all 2 row via fetchmany
89+
yield from c.execute("SELECT * from deserialize_cursor")
90+
r = yield from c.fetchmany(2)
91+
self.assertEqual([bob, jim], r, "fetchmany failed via DictCursor")
92+
yield from c.execute('commit')
93+
94+
@run_until_complete
95+
def test_deserialize_cursor_low_version(self):
96+
if self.havejson:
97+
return
98+
bob = ("bob", 21, '{"k1": "pretty", "k2": [18, 25]}')
99+
jim = ("jim", 56, '{"k1": "rich", "k2": [20, 60]}')
100+
fred = ("fred", 100, '{"k1": "longevity", "k2": [100, 160]}')
101+
# all assert test compare to the structure as would come
102+
# out from MySQLdb
103+
conn = self.conn
104+
c = yield from conn.cursor(self.cursor_type)
105+
106+
# pull back the single row dict for bob and check
107+
yield from c.execute("SELECT * from deserialize_cursor "
108+
"where name='bob'")
109+
r = yield from c.fetchone()
110+
self.assertEqual(bob, r, "fetchone via DeserializeCursor failed")
111+
# same again, but via fetchall => tuple)
112+
yield from c.execute("SELECT * from deserialize_cursor "
113+
"where name='bob'")
114+
r = yield from c.fetchall()
115+
self.assertEqual([bob], r,
116+
"fetch a 1 row result via fetchall failed via "
117+
"DeserializeCursor")
118+
# get all 3 row via fetchall
119+
yield from c.execute("SELECT * from deserialize_cursor")
120+
r = yield from c.fetchall()
121+
self.assertEqual([bob, jim, fred], r,
122+
"fetchall failed via DictCursor")
123+
124+
# get all 2 row via fetchmany
125+
yield from c.execute("SELECT * from deserialize_cursor")
126+
r = yield from c.fetchmany(2)
127+
self.assertEqual([bob, jim], r, "fetchmany failed via DictCursor")
128+
yield from c.execute('commit')
129+
130+
@run_until_complete
131+
def test_deserializedictcursor(self):
132+
if not self.havejson:
133+
return
134+
bob = {'name': 'bob', 'age': 21,
135+
'claim': {"k1": "pretty", "k2": [18, 25]}}
136+
conn = self.conn
137+
c = yield from conn.cursor(aiomysql.cursors.DeserializationCursor,
138+
aiomysql.cursors.DictCursor)
139+
yield from c.execute("SELECT * from deserialize_cursor "
140+
"where name='bob'")
141+
r = yield from c.fetchall()
142+
self.assertEqual([bob], r,
143+
"fetch a 1 row result via fetchall failed via "
144+
"DeserializationCursor")
145+
146+
@run_until_complete
147+
def test_ssdeserializecursor(self):
148+
if not self.havejson:
149+
return
150+
conn = self.conn
151+
c = yield from conn.cursor(aiomysql.cursors.SSCursor,
152+
aiomysql.cursors.DeserializationCursor)
153+
yield from c.execute("SELECT * from deserialize_cursor "
154+
"where name='bob'")
155+
r = yield from c.fetchall()
156+
self.assertEqual([self.bob], r,
157+
"fetch a 1 row result via fetchall failed via "
158+
"DeserializationCursor")
159+
160+
@run_until_complete
161+
def test_ssdeserializedictcursor(self):
162+
if not self.havejson:
163+
return
164+
bob = {'name': 'bob', 'age': 21,
165+
'claim': {"k1": "pretty", "k2": [18, 25]}}
166+
conn = self.conn
167+
c = yield from conn.cursor(aiomysql.cursors.SSCursor,
168+
aiomysql.cursors.DeserializationCursor,
169+
aiomysql.cursors.DictCursor)
170+
yield from c.execute("SELECT * from deserialize_cursor "
171+
"where name='bob'")
172+
r = yield from c.fetchall()
173+
self.assertEqual([bob], r,
174+
"fetch a 1 row result via fetchall failed via "
175+
"DeserializationCursor")

0 commit comments

Comments
 (0)