diff --git a/.gitignore b/.gitignore index 94d5bb2e..c3b79f9e 100644 --- a/.gitignore +++ b/.gitignore @@ -67,3 +67,4 @@ target/ tests/fixtures/my.cnf .pytest_cache +/tests/test_reconnect.py diff --git a/aiomysql/cursors.py b/aiomysql/cursors.py index 3401bdbf..67d5a48d 100644 --- a/aiomysql/cursors.py +++ b/aiomysql/cursors.py @@ -2,6 +2,7 @@ import json import warnings import contextlib +import asyncio from pymysql.err import ( Warning, Error, InterfaceError, DataError, @@ -22,6 +23,14 @@ r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z", re.IGNORECASE | re.DOTALL) +ERROR_CODES_FOR_RECONNECTING = [ + 1927, # ER_CONNECTION_KILLED + 1184, # ER_NEW_ABORTING_CONNECTION + 1152, # ER_ABORTING_CONNECTION, + 2003, # Can't connect to MySQL server + 2013, # Lost connection to MySQL server during query +] + class Cursor: """Cursor is used to interact with the database.""" @@ -236,11 +245,46 @@ async def execute(self, query, args=None): if args is not None: query = query % self._escape_args(args, conn) - await self._query(query) + try: + await self._query(query) + + except asyncio.CancelledError: + raise + + except Exception as main_error: + if not hasattr(main_error, 'args') or main_error.args[0] not in ERROR_CODES_FOR_RECONNECTING: + raise main_error + + logger.error(main_error) + sleep_time_list = [3] * 20 + sleep_time_list.insert(0, 1) + for attempt, sleep_time in enumerate(sleep_time_list): + try: + logger.warning('%s - Reconnecting to MySQL. Attempt %d of 21 for connection %s', conn._db, attempt + 1, id(conn)) + await conn.ping() + logger.info('%s - Successfully reconnected to MySQL after error for connection %s', conn._db, id(conn)) + await self._query(query) + break + + except asyncio.CancelledError: + raise + + except Exception as e: + if not hasattr(e, 'args') or e.args[0] not in ERROR_CODES_FOR_RECONNECTING: + break + + logger.error(e) + await asyncio.sleep(sleep_time) + + else: + logger.error('%s - Reconnecting to MySQL failed for connection %s', conn._db, id(conn)) + raise main_error + self._executed = query if self._echo: logger.info(query) logger.info("%r", args) + return self._rowcount async def executemany(self, query, args): diff --git a/aiomysql/log.py b/aiomysql/log.py index d632698e..fd6de6ba 100644 --- a/aiomysql/log.py +++ b/aiomysql/log.py @@ -3,3 +3,4 @@ # Name the logger after the package. logger = logging.getLogger(__package__) +logger.setLevel(logging.WARNING) diff --git a/aiomysql/pool.py b/aiomysql/pool.py index eaaddbe0..773bcc11 100644 --- a/aiomysql/pool.py +++ b/aiomysql/pool.py @@ -3,11 +3,15 @@ import asyncio import collections +import sys import warnings +from pymysql import OperationalError + +from .log import logger from .connection import connect from .utils import (_PoolContextManager, _PoolConnectionContextManager, - _PoolAcquireContextManager) + _PoolAcquireContextManager, TaskTransactionContextManager) def create_pool(minsize=1, maxsize=10, echo=False, pool_recycle=-1, @@ -50,6 +54,7 @@ def __init__(self, minsize, maxsize, echo, pool_recycle, loop, **kwargs): self._closed = False self._echo = echo self._recycle = pool_recycle + self._db = kwargs.get('db') @property def echo(self): @@ -71,6 +76,10 @@ def size(self): def freesize(self): return len(self._free) + @property + def db_name(self): + return self._db + async def clear(self): """Close all free connections in pool.""" async with self._cond: @@ -131,9 +140,32 @@ async def wait_closed(self): def acquire(self): """Acquire free connection from the pool.""" + if sys.version_info < (3, 7): + o_transaction_context_manager = TaskTransactionContextManager.get_transaction_context_manager(asyncio.Task.current_task()) + + else: + o_transaction_context_manager = TaskTransactionContextManager.get_transaction_context_manager(asyncio.current_task()) + + if o_transaction_context_manager: + return o_transaction_context_manager + coro = self._acquire() return _PoolAcquireContextManager(coro, self) + def acquire_with_transaction(self): + """Acquire free connection from the pool for transaction""" + if sys.version_info < (3, 7): + o_transaction_context_manager = TaskTransactionContextManager.get_transaction_context_manager(asyncio.Task.current_task()) + + else: + o_transaction_context_manager = TaskTransactionContextManager.get_transaction_context_manager(asyncio.current_task()) + + if o_transaction_context_manager: + return o_transaction_context_manager + + coro = self._acquire() + return TaskTransactionContextManager(coro, self) + async def _acquire(self): if self._closing: raise RuntimeError("Cannot acquire connection after closing pool") @@ -142,11 +174,12 @@ async def _acquire(self): await self._fill_free_pool(True) if self._free: conn = self._free.popleft() - assert not conn.closed, conn - assert conn not in self._used, (conn, self._used) + # assert not conn.closed, conn + # assert conn not in self._used, (conn, self._used) self._used.add(conn) return conn else: + logger.debug('%s - All connections (%d) are busy. Waiting for release connection', self._db, self.freesize) await self._cond.wait() async def _fill_free_pool(self, override_min): @@ -156,6 +189,7 @@ async def _fill_free_pool(self, override_min): while n < free_size: conn = self._free[-1] if conn._reader.at_eof() or conn._reader.exception(): + logger.debug('%s - Connection (%d) is removed from pool because of at_eof or exception', self._db, id(conn)) self._free.pop() conn.close() @@ -167,38 +201,56 @@ async def _fill_free_pool(self, override_min): self._free.pop() conn.close() - elif (self._recycle > -1 and - self._loop.time() - conn.last_usage > self._recycle): + elif self._recycle > -1 and self._loop.time() - conn.last_usage > self._recycle: + logger.debug('%s - Connection (%d) is removed from pool because of recycle time %d', self._db, id(conn), self._recycle) self._free.pop() conn.close() else: self._free.rotate() + n += 1 while self.size < self.minsize: - self._acquiring += 1 - try: - conn = await connect(echo=self._echo, loop=self._loop, - **self._conn_kwargs) - # raise exception if pool is closing - self._free.append(conn) - self._cond.notify() - finally: - self._acquiring -= 1 + await self.__create_new_connection() + if self._free: return if override_min and (not self.maxsize or self.size < self.maxsize): - self._acquiring += 1 + await self.__create_new_connection() + + async def __create_new_connection(self): + logger.debug('%s - Try to create new connection', self._db) + self._acquiring += 1 + try: try: - conn = await connect(echo=self._echo, loop=self._loop, - **self._conn_kwargs) - # raise exception if pool is closing - self._free.append(conn) - self._cond.notify() - finally: - self._acquiring -= 1 + conn = await connect(echo=self._echo, loop=self._loop, **self._conn_kwargs) + + except OperationalError as error: + logger.error(error) + sleep_time_list = [3] * 20 + for attempt, sleep_time in enumerate(sleep_time_list): + try: + logger.warning('%s - Connect to MySQL failed. Attempt %d of 20', self._db, attempt + 1) + conn = await connect(echo=self._echo, loop=self._loop, **self._conn_kwargs) + logger.info('%s - Successfully connect to MySQL after error', self._db) + break + + except OperationalError as e: + logger.error(e) + await asyncio.sleep(sleep_time) + + else: + logger.error('%s - Connect to MySQL failed', self._db) + raise error + + # raise exception if pool is closing + self._free.append(conn) + self._cond.notify() + + finally: + self._acquiring -= 1 async def _wakeup(self): async with self._cond: @@ -213,10 +265,10 @@ def release(self, conn): fut.set_result(None) if conn in self._terminated: - assert conn.closed, conn + # assert conn.closed, conn self._terminated.remove(conn) return fut - assert conn in self._used, (conn, self._used) + # assert conn in self._used, (conn, self._used) self._used.remove(conn) if not conn.closed: in_trans = conn.get_transaction_status() diff --git a/aiomysql/utils.py b/aiomysql/utils.py index 74ad99a7..190cb5d3 100644 --- a/aiomysql/utils.py +++ b/aiomysql/utils.py @@ -1,7 +1,11 @@ +import asyncio +import sys from collections.abc import Coroutine import struct +from .log import logger + def _pack_int24(n): return struct.pack(" 'TaskTransactionContextManager': + if sys.version_info < (3, 7): + task = asyncio.Task.current_task() + + else: + task = asyncio.current_task() + + if task in cls.__task_storage: + return cls.__task_storage[task] + + return TaskTransactionContextManager(coro, pool) + + @classmethod + def get_transaction_context_manager(cls, task=None) -> 'TaskTransactionContextManager': + if not task: + if sys.version_info < (3, 7): + task = asyncio.Task.current_task() + + else: + task = asyncio.current_task() + + return cls.__task_storage.get(task) + + def add_callback_on_commit(self, callback_func, **kwargs): + self._callback_list.append((callback_func, kwargs)) + + async def connection_begin(self): + if not self.__connection_transaction_begin: + await self._conn.begin() + self.__connection_transaction_begin = True + + async def connection_commit(self): + if self._counter <= 1: + await self._conn.commit() + for callback_func, kwargs in self._callback_list: + try: + if asyncio.iscoroutine(callback_func): + await callback_func(**kwargs) + + else: + callback_func(**kwargs) + + except Exception as e: + logger.exception(e) + + self._callback_list.clear() + self.__connection_committed = True + + async def connection_rollback(self): + await self._conn.rollback() + self._callback_list.clear() + self.__connection_rollbacked = True + + async def __aenter__(self): + self._counter += 1 + if not self._conn: + self._conn = await self._coro + + return TransactionConnection(self._conn) + + async def __aexit__(self, exc_type, exc, tb): + self._counter -= 1 + if self._counter <= 0: + if self.__connection_transaction_begin and not self.__connection_committed and not self.__connection_rollbacked: + if not exc_type: + logger.warning('sql operation was not committed. Try to commit by TaskTransactionContextManager') + await self.connection_commit() + + if sys.version_info < (3, 7): + self.__task_storage.pop(asyncio.Task.current_task(), None) + + else: + self.__task_storage.pop(asyncio.current_task(), None) + + try: + await self._pool.release(self._conn) + + finally: + self._pool = None + self._conn = None + self.__connection_transaction_begin = False + self.__connection_committed = False + self.__connection_rollbacked = False + + +class TransactionConnection: + + def __init__(self, conn): + self.__conn = conn + + def __str__(self): + return 'TransactionConnection ' + str(self.__conn) + + async def begin(self): + """Begin transaction.""" + if TaskTransactionContextManager.get_transaction_context_manager(): + await TaskTransactionContextManager.get_transaction_context_manager().connection_begin() + + else: + await self.__conn.begin() + + async def commit(self): + """Commit changes to stable storage.""" + if TaskTransactionContextManager.get_transaction_context_manager(): + await TaskTransactionContextManager.get_transaction_context_manager().connection_commit() + + else: + await self.__conn.commit() + + async def rollback(self): + """Roll back the current transaction.""" + if TaskTransactionContextManager.get_transaction_context_manager(): + await TaskTransactionContextManager.get_transaction_context_manager().connection_rollback() + + else: + await self.__conn.rollback() + + def __getattr__(self, item): + return getattr(self.__conn, item) + + @property + def conn(self): + return self.__conn diff --git a/tests/test_pool.py b/tests/test_pool.py index f2872d67..876764cd 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -557,3 +557,21 @@ async def test_pool_maxsize_unlimited_minsize_1(pool_creator, loop): async with pool.acquire() as conn: cur = await conn.cursor() await cur.execute('SELECT 1;') + + +@pytest.mark.run_loop +async def test_pool_maxsize_unlimited(pool_creator, loop): + pool = await pool_creator(minsize=0, maxsize=0) + + async with pool.acquire() as conn: + cur = await conn.cursor() + await cur.execute('SELECT 1;') + + +@pytest.mark.run_loop +async def test_pool_maxsize_unlimited_minsize_1(pool_creator, loop): + pool = await pool_creator(minsize=1, maxsize=0) + + async with pool.acquire() as conn: + cur = await conn.cursor() + await cur.execute('SELECT 1;') diff --git a/tests/test_transaction_context_manager.py b/tests/test_transaction_context_manager.py new file mode 100644 index 00000000..034b474a --- /dev/null +++ b/tests/test_transaction_context_manager.py @@ -0,0 +1,99 @@ +import asyncio +import logging + +from pymysql import InterfaceError, OperationalError + +import aiomysql + + +async def start_test(): + logging.getLogger().setLevel(logging.DEBUG) + logging.info('start_test') + host = '10.20.166.2' + port = 3306 + user = 'root' + password = 'Kub@Rozpruw@cz007' + db = 'manager' + pool: aiomysql.Pool = await aiomysql.create_pool( + host=host, user=user, password=password, db=db, + port=port, minsize=1, maxsize=2 + ) + + for x in range(2): + loop.create_task(test_connection1(f'name{x}', pool)) + + +async def test_connection1(name, pool: aiomysql.Pool): + sql = """ + SELECT 1; + """ + while True: + logging.info(f"{name}: pool_freesize %d", pool.freesize) + async with pool.acquire_with_transaction() as connection: + logging.info(connection) + await connection.begin() + async with connection.cursor() as cursor: + try: + await cursor.execute(sql) + await cursor.fetchone() + logging.info(f"{name}: OK") + + except asyncio.CancelledError: + raise + + except Exception as e: + logging.error(f"{name}: {e}") + # raise + + await test_connection2('test_connection2', pool) + await test_connection3('test_connection3', pool) + + await connection.rollback() + + await asyncio.sleep(1) + + +async def test_connection2(name, pool: aiomysql.Pool): + sql = """ + SELECT 1; + """ + async with pool.acquire_with_transaction() as connection: + logging.info(connection) + await connection.begin() + async with connection.cursor() as cursor: + try: + await cursor.execute(sql) + await cursor.fetchone() + logging.info(f"{name}: OK") + + except asyncio.CancelledError: + raise + + except Exception as e: + logging.error(f"{name}: {e}") + # raise + + +async def test_connection3(name, pool: aiomysql.Pool): + sql = """ + SELECT 1; + """ + async with pool.acquire() as connection: + logging.info(connection) + async with connection.cursor() as cursor: + try: + await cursor.execute(sql) + await cursor.fetchone() + logging.info(f"{name}: OK") + + except asyncio.CancelledError: + raise + + except Exception as e: + logging.error(f"{name}: {e}") + # raise + + +loop = asyncio.get_event_loop() +loop.run_until_complete(start_test()) +loop.run_forever()