Skip to content

Commit 4d2708c

Browse files
committed
Handle correct disconnect for failed PostgreSQL transactions
1 parent df63ae7 commit 4d2708c

File tree

5 files changed

+35
-6
lines changed

5 files changed

+35
-6
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ Pass database URLs for those you want to run the tests against. Comma separated
105105
list.
106106

107107
```bash
108-
BASED_TEST_DB_URLS='postgresql://postgres:postgres@localhost:5432/postgres,mysql://root:[email protected]:3306/mariadb' make test`
108+
BASED_TEST_DB_URLS='postgresql://postgres:postgres@localhost:5432/postgres,mysql://root:[email protected]:3306/mariadb' make test
109109
```
110110

111111
## TODO

based/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.6.1"
1+
__version__ = "0.7.0"
22

33
from based.backends import Session
44
from based.database import Database

based/backends/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ class Session:
118118
_conn: typing.Any
119119
_dialect: Dialect
120120
_transaction_stack: typing.List[str]
121+
transaction_failed: bool
121122

122123
def __init__( # noqa: D107
123124
self,
@@ -127,6 +128,7 @@ def __init__( # noqa: D107
127128
self._conn = conn
128129
self._dialect = dialect
129130
self._transaction_stack = []
131+
self.transaction_failed = False
130132

131133
async def _execute(
132134
self,
@@ -154,7 +156,11 @@ async def _execute(
154156
cursor:
155157
A cursor returned by the database driver after executing the query.
156158
"""
157-
return await self._conn.execute(query, params)
159+
try:
160+
return await self._conn.execute(query, params)
161+
except Exception:
162+
self.transaction_failed = True
163+
raise
158164

159165
def _compile_query(
160166
self,

based/backends/postgresql.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,20 @@ async def _session(self) -> typing.AsyncGenerator["Session", None]:
8484
await session.cancel_transaction()
8585
raise
8686
else:
87-
await session.commit_transaction()
87+
if session.transaction_failed:
88+
await session.cancel_transaction()
89+
else:
90+
await session.commit_transaction()
8891
else:
8992
try:
9093
yield session
9194
except Exception:
9295
await connection.rollback()
9396
raise
9497
else:
95-
await connection.commit()
98+
if session.transaction_failed:
99+
await connection.rollback()
100+
else:
101+
await connection.commit()
96102
finally:
97103
await self._pool.putconn(connection)

tests/test_backend.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
import typing
23

34
import pytest
@@ -33,7 +34,9 @@ async def test_database_force_rollback_with_lock(
3334
title, year = gen_movie()
3435

3536
async with based.Database(
36-
database_url, force_rollback=True, use_lock=True,
37+
database_url,
38+
force_rollback=True,
39+
use_lock=True,
3740
) as database:
3841
async with database.session() as session:
3942
query = table.insert().values(title=title, year=year)
@@ -85,3 +88,17 @@ async def test_abstract_backend(database_url: str):
8588

8689
with pytest.raises(NotImplementedError):
8790
await backend.disconnect()
91+
92+
93+
async def test_disconnect_with_failed_transaction_force_rollback(database_url: str):
94+
async with based.Database(database_url, force_rollback=True) as database:
95+
async with database.session() as session:
96+
with contextlib.suppress(Exception):
97+
await session.execute("SELECT 1 FROM nonexistent;")
98+
99+
100+
async def test_disconnect_with_failed_transaction_no_force_rollback(database_url: str):
101+
async with based.Database(database_url, force_rollback=False) as database:
102+
async with database.session() as session:
103+
with contextlib.suppress(Exception):
104+
await session.execute("SELECT 1 FROM nonexistent;")

0 commit comments

Comments
 (0)