diff --git a/aiomysql/sa/connection.py b/aiomysql/sa/connection.py index f6cabab0..5dd423e2 100644 --- a/aiomysql/sa/connection.py +++ b/aiomysql/sa/connection.py @@ -103,7 +103,15 @@ async def _executemany(self, query, dps, cursor): "and execution with parameters" ) elif isinstance(query, ClauseElement): - compiled = query.compile(dialect=self._dialect) + if self._compiled_cache is not None: + key = query + compiled = self._compiled_cache.get(key) + if not compiled: + compiled = query.compile(dialect=self._dialect) + self._compiled_cache[key] = compiled + else: + compiled = query.compile(dialect=self._dialect) + params = [] is_update = isinstance(query, UpdateBase) for dp in dps: @@ -151,7 +159,6 @@ async def _execute(self, query, *multiparams, **params): compiled = query.compile(dialect=self._dialect) if dp and dp.keys() == compiled.params.keys() \ or not (dp or compiled.params): - # we only want queries with bound params in cache self._compiled_cache[key] = compiled else: compiled = query.compile(dialect=self._dialect) diff --git a/tests/sa/test_sa_compiled_cache.py b/tests/sa/test_sa_compiled_cache.py index 905b637d..bd711936 100644 --- a/tests/sa/test_sa_compiled_cache.py +++ b/tests/sa/test_sa_compiled_cache.py @@ -53,6 +53,88 @@ async def start(self): await conn.execute(tbl.insert().values(val='some_val_3')) await tx.commit() + def test_cache_executemany(self): + async def go(): + cache = dict() + engine = await self.make_engine(compiled_cache=cache) + async with engine.acquire() as conn: + select_all = tbl.select() + select_by_val = tbl.select().where( + tbl.c.val == bindparam('value') + ) + + # check insert with params not added to cache + await conn.execute(tbl.insert().values([ + {'val': 'some_val_100'}, + {'val': 'some_val_101'}, + {'val': 'some_val_102'}, + ])) + self.assertEqual(0, len(cache)) + + # check insert with bound param added to cache + q = tbl.insert().values(val=bindparam('value')) + test_values = [ + {'value': 'some_val_103'}, + {'value': 'some_val_104'}, + {'value': 'some_val_105'}, + ] + await conn.execute(q, test_values) + self.assertEqual(1, len(cache)) + + await conn.execute(q, test_values) + self.assertEqual(1, len(cache)) + + cursor = await conn.execute(select_all) + rows = await cursor.fetchall() + self.assertEqual(12, len(rows)) + self.assertEqual(2, len(cache)) + + # check update with bound params added to cache + q = tbl.update().where( + tbl.c.val == bindparam('value') + ).values(val=bindparam('update')) + + test_upd_values_1 = [ + {'value': 'some_val_100', 'update': 'updated_val_100'}, + {'value': 'some_val_101', 'update': 'updated_val_101'}, + {'value': 'some_val_102', 'update': 'updated_val_102'}, + ] + res = await conn.execute(q, test_upd_values_1) + self.assertEqual(3, res.rowcount) + self.assertEqual(3, len(cache)) + + test_upd_values_2 = [ + {'value': 'updated_val_100', 'update': 'updated_val_200'}, + {'value': 'updated_val_101', 'update': 'updated_val_201'}, + {'value': 'updated_val_102', 'update': 'updated_val_202'}, + ] + res = await conn.execute(q, test_upd_values_2) + self.assertEqual(3, res.rowcount) + self.assertEqual(3, len(cache)) + + for test_value in test_upd_values_2: + cursor = await conn.execute( + select_by_val, value=test_value['update'] + ) + row = await cursor.fetchone() + self.assertEqual(test_value['update'], row.val) + + self.assertEqual(4, len(cache)) + + # check delete with bound params added to cache + q = tbl.delete().where(tbl.c.val == bindparam('value')) + + test_del_values = [ + {'value': 'updated_val_200'}, + {'value': 'updated_val_201'}, + {'value': 'updated_val_202'}, + ] + res = await conn.execute(q, test_del_values) + self.assertEqual(3, res.rowcount) + self.assertEqual(5, len(cache)) + + self.loop.run_until_complete(go()) + def test_cache(self): async def go(): cache = dict()