diff --git a/databases/core.py b/databases/core.py index 8394ab5c..981e3c47 100644 --- a/databases/core.py +++ b/databases/core.py @@ -7,6 +7,7 @@ from types import TracebackType from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit +from pymysql.err import OperationalError from sqlalchemy import text from sqlalchemy.sql import ClauseElement @@ -391,25 +392,37 @@ async def start(self) -> "Transaction": async with self._connection._transaction_lock: is_root = not self._connection._transaction_stack await self._connection.__aenter__() - await self._transaction.start( - is_root=is_root, extra_options=self._extra_options - ) - self._connection._transaction_stack.append(self) + try: + await self._transaction.start( + is_root=is_root, extra_options=self._extra_options + ) + self._connection._transaction_stack.append(self) + except Exception as e: + await self._connection.__aexit__() + raise e return self async def commit(self) -> None: async with self._connection._transaction_lock: assert self._connection._transaction_stack[-1] is self self._connection._transaction_stack.pop() - await self._transaction.commit() - await self._connection.__aexit__() + try: + await self._transaction.commit() + except Exception as e: + raise e + finally: + await self._connection.__aexit__() async def rollback(self) -> None: async with self._connection._transaction_lock: assert self._connection._transaction_stack[-1] is self self._connection._transaction_stack.pop() - await self._transaction.rollback() - await self._connection.__aexit__() + try: + await self._transaction.rollback() + except Exception as e: + raise e + finally: + await self._connection.__aexit__() class _EmptyNetloc(str): diff --git a/requirements.txt b/requirements.txt index 0699d3cc..1394fcc8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,6 +20,7 @@ pytest==7.1.2 pytest-cov==3.0.0 starlette==0.20.4 requests==2.28.1 +types-PyMySQL==1.0.19.1 # Documentation mkdocs==1.3.1 @@ -29,3 +30,4 @@ mkautodoc==0.1.0 # Packaging twine==4.0.1 wheel==0.38.1 +