Skip to content

Commit aec31dd

Browse files
vlansejettify
authored andcommitted
ability to execute precompiled sqlalchemy queries (#294)
* ability to execute precompiled sqlalchemy queries * global cache for compiled queries * update formatting * use only query as a key in compiled cache
1 parent a6f3ee9 commit aec31dd

File tree

5 files changed

+169
-10
lines changed

5 files changed

+169
-10
lines changed

CHANGES.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
Changes
22
-------
33

4+
0.0.16 (2018-05-21)
5+
^^^^^^^^^^^^^^^^^^^
6+
7+
* Added ability to execute precompiled sqlalchemy queries
8+
9+
410
0.0.15 (2018-05-20)
511
^^^^^^^^^^^^^^^^^^^
612

aiomysql/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from .cursors import Cursor, SSCursor, DictCursor, SSDictCursor
3434
from .pool import create_pool, Pool
3535

36-
__version__ = '0.0.15'
36+
__version__ = '0.0.16'
3737

3838
__all__ = [
3939

aiomysql/sa/connection.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515

1616
class SAConnection:
1717

18-
def __init__(self, connection, engine):
18+
def __init__(self, connection, engine, compiled_cache=None):
1919
self._connection = connection
2020
self._transaction = None
2121
self._savepoint_seq = 0
2222
self._weak_results = weakref.WeakSet()
2323
self._engine = engine
2424
self._dialect = engine.dialect
25+
self._compiled_cache = compiled_cache
2526

2627
def execute(self, query, *multiparams, **params):
2728
"""Executes a SQL query with optional parameters.
@@ -76,8 +77,18 @@ async def _execute(self, query, *multiparams, **params):
7677
if isinstance(query, str):
7778
await cursor.execute(query, dp or None)
7879
elif isinstance(query, ClauseElement):
79-
compiled = query.compile(dialect=self._dialect)
80-
# parameters = compiled.params
80+
if self._compiled_cache is not None:
81+
key = query
82+
compiled = self._compiled_cache.get(key)
83+
if not compiled:
84+
compiled = query.compile(dialect=self._dialect)
85+
if dp and dp.keys() == compiled.params.keys() \
86+
or not (dp or compiled.params):
87+
# we only want queries with bound params in cache
88+
self._compiled_cache[key] = compiled
89+
else:
90+
compiled = query.compile(dialect=self._dialect)
91+
8192
if not isinstance(query, DDLElement):
8293
if dp and isinstance(dp, (list, tuple)):
8394
if isinstance(query, UpdateBase):

aiomysql/sa/engine.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,17 @@
2020

2121

2222
def create_engine(minsize=1, maxsize=10, loop=None,
23-
dialect=_dialect, pool_recycle=-1, **kwargs):
23+
dialect=_dialect, pool_recycle=-1, compiled_cache=None,
24+
**kwargs):
2425
"""A coroutine for Engine creation.
2526
2627
Returns Engine instance with embedded connection pool.
2728
2829
The pool has *minsize* opened connections to PostgreSQL server.
2930
"""
3031
coro = _create_engine(minsize=minsize, maxsize=maxsize, loop=loop,
31-
dialect=dialect, pool_recycle=pool_recycle, **kwargs)
32+
dialect=dialect, pool_recycle=pool_recycle,
33+
compiled_cache=compiled_cache, **kwargs)
3234
compatible_cursor_classes = [Cursor]
3335
# Without provided kwarg, default is default cursor from Connection class
3436
if kwargs.get('cursorclass', Cursor) not in compatible_cursor_classes:
@@ -38,7 +40,8 @@ def create_engine(minsize=1, maxsize=10, loop=None,
3840

3941

4042
async def _create_engine(minsize=1, maxsize=10, loop=None,
41-
dialect=_dialect, pool_recycle=-1, **kwargs):
43+
dialect=_dialect, pool_recycle=-1,
44+
compiled_cache=None, **kwargs):
4245

4346
if loop is None:
4447
loop = asyncio.get_event_loop()
@@ -47,7 +50,7 @@ async def _create_engine(minsize=1, maxsize=10, loop=None,
4750
pool_recycle=pool_recycle, **kwargs)
4851
conn = await pool.acquire()
4952
try:
50-
return Engine(dialect, pool, **kwargs)
53+
return Engine(dialect, pool, compiled_cache=compiled_cache, **kwargs)
5154
finally:
5255
pool.release(conn)
5356

@@ -61,9 +64,10 @@ class Engine:
6164
create_engine coroutine.
6265
"""
6366

64-
def __init__(self, dialect, pool, **kwargs):
67+
def __init__(self, dialect, pool, compiled_cache=None, **kwargs):
6568
self._dialect = dialect
6669
self._pool = pool
70+
self._compiled_cache = compiled_cache
6771
self._conn_kw = kwargs
6872

6973
@property
@@ -124,7 +128,7 @@ def acquire(self):
124128

125129
async def _acquire(self):
126130
raw = await self._pool.acquire()
127-
conn = SAConnection(raw, self)
131+
conn = SAConnection(raw, self, compiled_cache=self._compiled_cache)
128132
return conn
129133

130134
def release(self, conn):

tests/sa/test_sa_compiled_cache.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import asyncio
2+
from aiomysql import sa
3+
from sqlalchemy import bindparam
4+
5+
import os
6+
import unittest
7+
8+
from sqlalchemy import MetaData, Table, Column, Integer, String
9+
10+
meta = MetaData()
11+
tbl = Table('sa_tbl_cache_test', meta,
12+
Column('id', Integer, nullable=False,
13+
primary_key=True),
14+
Column('val', String(255)))
15+
16+
17+
class TestCompiledCache(unittest.TestCase):
18+
def setUp(self):
19+
self.loop = asyncio.new_event_loop()
20+
asyncio.set_event_loop(None)
21+
self.host = os.environ.get('MYSQL_HOST', 'localhost')
22+
self.port = int(os.environ.get('MYSQL_PORT', 3306))
23+
self.user = os.environ.get('MYSQL_USER', 'root')
24+
self.db = os.environ.get('MYSQL_DB', 'test_pymysql')
25+
self.password = os.environ.get('MYSQL_PASSWORD', '')
26+
self.engine = self.loop.run_until_complete(self.make_engine())
27+
self.loop.run_until_complete(self.start())
28+
29+
def tearDown(self):
30+
self.engine.terminate()
31+
self.loop.run_until_complete(self.engine.wait_closed())
32+
self.loop.close()
33+
34+
async def make_engine(self, **kwargs):
35+
return (await sa.create_engine(db=self.db,
36+
user=self.user,
37+
password=self.password,
38+
host=self.host,
39+
port=self.port,
40+
loop=self.loop,
41+
minsize=10,
42+
**kwargs))
43+
44+
async def start(self):
45+
async with self.engine.acquire() as conn:
46+
tx = await conn.begin()
47+
await conn.execute("DROP TABLE IF EXISTS "
48+
"sa_tbl_cache_test")
49+
await conn.execute("CREATE TABLE sa_tbl_cache_test"
50+
"(id serial, val varchar(255))")
51+
await conn.execute(tbl.insert().values(val='some_val_1'))
52+
await conn.execute(tbl.insert().values(val='some_val_2'))
53+
await conn.execute(tbl.insert().values(val='some_val_3'))
54+
await tx.commit()
55+
56+
def test_cache(self):
57+
async def go():
58+
cache = dict()
59+
engine = await self.make_engine(compiled_cache=cache)
60+
async with engine.acquire() as conn:
61+
# check select with params not added to cache
62+
q = tbl.select().where(tbl.c.val == 'some_val_1')
63+
cursor = await conn.execute(q)
64+
row = await cursor.fetchone()
65+
self.assertEqual('some_val_1', row.val)
66+
self.assertEqual(0, len(cache))
67+
68+
# check select with bound params added to cache
69+
select_by_val = tbl.select().where(
70+
tbl.c.val == bindparam('value')
71+
)
72+
cursor = await conn.execute(
73+
select_by_val, {'value': 'some_val_3'}
74+
)
75+
row = await cursor.fetchone()
76+
self.assertEqual('some_val_3', row.val)
77+
self.assertEqual(1, len(cache))
78+
79+
cursor = await conn.execute(
80+
select_by_val, value='some_val_2'
81+
)
82+
row = await cursor.fetchone()
83+
self.assertEqual('some_val_2', row.val)
84+
self.assertEqual(1, len(cache))
85+
86+
select_all = tbl.select()
87+
cursor = await conn.execute(select_all)
88+
rows = await cursor.fetchall()
89+
self.assertEqual(3, len(rows))
90+
self.assertEqual(2, len(cache))
91+
92+
# check insert with bound params not added to cache
93+
await conn.execute(tbl.insert().values(val='some_val_4'))
94+
self.assertEqual(2, len(cache))
95+
96+
# check insert with bound params added to cache
97+
q = tbl.insert().values(val=bindparam('value'))
98+
await conn.execute(q, value='some_val_5')
99+
self.assertEqual(3, len(cache))
100+
101+
await conn.execute(q, value='some_val_6')
102+
self.assertEqual(3, len(cache))
103+
104+
await conn.execute(q, {'value': 'some_val_7'})
105+
self.assertEqual(3, len(cache))
106+
107+
cursor = await conn.execute(select_all)
108+
rows = await cursor.fetchall()
109+
self.assertEqual(7, len(rows))
110+
self.assertEqual(3, len(cache))
111+
112+
# check update with params not added to cache
113+
q = tbl.update().where(
114+
tbl.c.val == 'some_val_1'
115+
).values(val='updated_val_1')
116+
await conn.execute(q)
117+
self.assertEqual(3, len(cache))
118+
cursor = await conn.execute(
119+
select_by_val, value='updated_val_1'
120+
)
121+
row = await cursor.fetchone()
122+
self.assertEqual('updated_val_1', row.val)
123+
124+
# check update with bound params added to cache
125+
q = tbl.update().where(
126+
tbl.c.val == bindparam('value')
127+
).values(val=bindparam('update'))
128+
await conn.execute(
129+
q, value='some_val_2', update='updated_val_2'
130+
)
131+
self.assertEqual(4, len(cache))
132+
cursor = await conn.execute(
133+
select_by_val, value='updated_val_2'
134+
)
135+
row = await cursor.fetchone()
136+
self.assertEqual('updated_val_2', row.val)
137+
138+
self.loop.run_until_complete(go())

0 commit comments

Comments
 (0)