Skip to content

Commit 30840b9

Browse files
authored
Merge pull request #14 from tarsil/fix/transactions
Fixing the concurrent usage of connections and transactions
2 parents 9db0605 + 3527fe6 commit 30840b9

File tree

7 files changed

+381
-27
lines changed

7 files changed

+381
-27
lines changed

databasez/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from databasez.core import Database, DatabaseURL
22

3-
__version__ = "0.4.0"
3+
__version__ = "0.5.0"
44

55
__all__ = ["Database", "DatabaseURL"]

databasez/core.py

Lines changed: 76 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
import functools
44
import logging
55
import typing
6+
import weakref
67
from contextvars import ContextVar
78
from types import TracebackType
89
from urllib.parse import SplitResult, parse_qsl, unquote, urlencode, urlsplit
910

1011
from sqlalchemy import text
1112
from sqlalchemy.sql import ClauseElement
12-
from sqlalchemy.util._concurrency_py3k import greenlet_spawn
1313

1414
from databasez.importer import import_from_string
15-
from databasez.interfaces import DatabaseBackend, Record
15+
from databasez.interfaces import DatabaseBackend, Record, TransactionBackend
1616

1717
if typing.TYPE_CHECKING:
1818
from databasez.types import DictAny
@@ -35,6 +35,11 @@
3535
logger = logging.getLogger("databasez")
3636

3737

38+
ACTIVE_TRANSACTIONS: ContextVar[
39+
typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"]
40+
] = ContextVar("databasez:active_transactions", default=None)
41+
42+
3843
class Database:
3944
"""
4045
An abstraction on the top of the EncodeORM databases.Database object.
@@ -72,6 +77,7 @@ class Database:
7277
}
7378
DIRECT_URL_SCHEME = {"sqlite"}
7479
MANDATORY_FIELDS = ["host", "port", "user", "database"]
80+
_connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']"
7581

7682
def __init__(
7783
self,
@@ -92,6 +98,7 @@ def __init__(
9298
self.url = DatabaseURL(_url) # type: ignore
9399
self.options = options
94100
self.is_connected = False
101+
self._connection_map = weakref.WeakKeyDictionary()
95102

96103
self._force_rollback = force_rollback
97104

@@ -100,9 +107,6 @@ def __init__(
100107
assert issubclass(backend_cls, DatabaseBackend)
101108
self._backend = backend_cls(self.url, **self.options)
102109

103-
# Connections are stored as task-local state.
104-
self._connection_context: ContextVar = ContextVar("connection_context")
105-
106110
# When `force_rollback=True` is used, we use a single global
107111
# connection, within a transaction that always rolls back.
108112
self._global_connection: typing.Optional[Connection] = None
@@ -164,6 +168,30 @@ def _build_url_for_direct_url_scheme(self, scheme: str, database: str) -> str:
164168
"""
165169
return f"{scheme}:///{database}"
166170

171+
@property
172+
def _current_task(self) -> asyncio.Task:
173+
task = asyncio.current_task()
174+
if not task:
175+
raise RuntimeError("No currently active asyncio.Task found")
176+
return task
177+
178+
@property
179+
def _connection(self) -> typing.Optional["Connection"]:
180+
return self._connection_map.get(self._current_task)
181+
182+
@_connection.setter
183+
def _connection(
184+
self, connection: typing.Optional["Connection"]
185+
) -> typing.Optional["Connection"]:
186+
task = self._current_task
187+
188+
if connection is None:
189+
self._connection_map.pop(task, None)
190+
else:
191+
self._connection_map[task] = connection
192+
193+
return self._connection
194+
167195
async def connect(self) -> None:
168196
"""
169197
Establish the connection pool.
@@ -180,7 +208,7 @@ async def connect(self) -> None:
180208
assert self._global_connection is None
181209
assert self._global_transaction is None
182210

183-
self._global_connection = Connection(self._backend)
211+
self._global_connection = Connection(self, self._backend)
184212
self._global_transaction = self._global_connection.transaction(force_rollback=True)
185213

186214
await self._global_transaction.__aenter__()
@@ -202,7 +230,7 @@ async def disconnect(self) -> None:
202230
self._global_transaction = None
203231
self._global_connection = None
204232
else:
205-
self._connection_context = ContextVar("connection_context")
233+
self._connection = None
206234

207235
await self._backend.disconnect()
208236
logger.info(
@@ -274,12 +302,9 @@ def connection(self) -> "Connection":
274302
if self._global_connection is not None:
275303
return self._global_connection
276304

277-
try:
278-
return self._connection_context.get() # type: ignore
279-
except LookupError:
280-
connection = Connection(self._backend)
281-
self._connection_context.set(connection)
282-
return connection
305+
if not self._connection:
306+
self._connection = Connection(self, self._backend)
307+
return self._connection
283308

284309
def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction":
285310
return Transaction(self.connection, force_rollback=force_rollback, **kwargs)
@@ -300,7 +325,8 @@ def _get_backend(self) -> str:
300325

301326

302327
class Connection:
303-
def __init__(self, backend: DatabaseBackend) -> None:
328+
def __init__(self, database: Database, backend: DatabaseBackend) -> None:
329+
self._database = database
304330
self._backend = backend
305331

306332
self._connection_lock = asyncio.Lock()
@@ -334,6 +360,7 @@ async def __aexit__(
334360
self._connection_counter -= 1
335361
if self._connection_counter == 0:
336362
await self._connection.release()
363+
self._database._connection = None
337364

338365
async def fetch_all(
339366
self,
@@ -398,11 +425,6 @@ def connection_callable() -> Connection:
398425
def raw_connection(self) -> typing.Any:
399426
return self._connection.raw_connection
400427

401-
async def run_sync(
402-
self, fn: typing.Callable[..., typing.Any], *arg: typing.Any, **kw: typing.Any
403-
) -> typing.Any:
404-
return await greenlet_spawn(fn, self._connection.raw_connection, *arg, **kw)
405-
406428
@staticmethod
407429
def _build_query(
408430
query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None
@@ -431,6 +453,37 @@ def __init__(
431453
self._force_rollback = force_rollback
432454
self._extra_options = kwargs
433455

456+
@property
457+
def _connection(self) -> "Connection":
458+
# Returns the same connection if called multiple times
459+
return self._connection_callable()
460+
461+
@property
462+
def _transaction(self) -> typing.Optional["TransactionBackend"]:
463+
transactions = ACTIVE_TRANSACTIONS.get()
464+
if transactions is None:
465+
return None
466+
467+
return transactions.get(self, None)
468+
469+
@_transaction.setter
470+
def _transaction(
471+
self, transaction: typing.Optional["TransactionBackend"]
472+
) -> typing.Optional["TransactionBackend"]:
473+
transactions = ACTIVE_TRANSACTIONS.get()
474+
if transactions is None:
475+
transactions = weakref.WeakKeyDictionary()
476+
else:
477+
transactions = transactions.copy()
478+
479+
if transaction is None:
480+
transactions.pop(self, None)
481+
else:
482+
transactions[self] = transaction
483+
484+
ACTIVE_TRANSACTIONS.set(transactions)
485+
return transactions.get(self, None)
486+
434487
async def __aenter__(self) -> "Transaction":
435488
"""
436489
Called when entering `async with database.transaction()`
@@ -471,7 +524,6 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
471524
return wrapper # type: ignore
472525

473526
async def start(self) -> "Transaction":
474-
self._connection = self._connection_callable()
475527
self._transaction = self._connection._connection.transaction()
476528

477529
async with self._connection._transaction_lock:
@@ -485,15 +537,19 @@ async def commit(self) -> None:
485537
async with self._connection._transaction_lock:
486538
assert self._connection._transaction_stack[-1] is self
487539
self._connection._transaction_stack.pop()
540+
assert self._transaction is not None
488541
await self._transaction.commit()
489542
await self._connection.__aexit__()
543+
self._transaction = None
490544

491545
async def rollback(self) -> None:
492546
async with self._connection._transaction_lock:
493547
assert self._connection._transaction_stack[-1] is self
494548
self._connection._transaction_stack.pop()
549+
assert self._transaction is not None
495550
await self._transaction.rollback()
496551
await self._connection.__aexit__()
552+
self._transaction = None
497553

498554

499555
class _EmptyNetloc(str):

docs/connections-and-transactions.md

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@ to connect to the database.
3737

3838
## Connecting and disconnecting
3939

40-
You can control the database connect/disconnect, by using it as a async context manager.
40+
You can control the database connection, by using it as a async context manager.
4141

4242
```python
4343
async with Database(DATABASE_URL) as database:
4444
...
4545
```
4646

47-
Or by using explicit connection and disconnection:
47+
Or by using explicit `.connect()` and `disconnect()`:
4848

4949
```python
5050
database = Database(DATABASE_URL)
@@ -246,11 +246,54 @@ async def create_users(request):
246246
...
247247
```
248248

249-
Transaction blocks are managed as task-local state. Nested transactions
250-
are fully supported, and are implemented using database savepoints.
249+
The state of a transaction is liked to the connection used in the currently executing async task.
250+
If you would like to influence an active transaction from another task, the connection must be
251+
shared:
251252

252253
Transaction isolation-level can be specified if the driver backend supports that:
253254

255+
```python
256+
async def add_excitement(connnection: databases.core.Connection, id: int):
257+
await connection.execute(
258+
"UPDATE notes SET text = CONCAT(text, '!!!') WHERE id = :id",
259+
{"id": id}
260+
)
261+
262+
263+
async with Database(database_url) as database:
264+
async with database.transaction():
265+
# This note won't exist until the transaction closes...
266+
await database.execute(
267+
"INSERT INTO notes(id, text) values (1, 'databases is cool')"
268+
)
269+
# ...but child tasks can use this connection now!
270+
await asyncio.create_task(add_excitement(database.connection(), id=1))
271+
272+
await database.fetch_val("SELECT text FROM notes WHERE id=1")
273+
# ^ returns: "databases is cool!!!"
274+
```
275+
276+
Nested transactions are fully supported, and are implemented using database savepoints:
277+
278+
```python
279+
async with databases.Database(database_url) as db:
280+
async with db.transaction() as outer:
281+
# Do something in the outer transaction
282+
...
283+
284+
# Suppress to prevent influence on the outer transaction
285+
with contextlib.suppress(ValueError):
286+
async with db.transaction():
287+
# Do something in the inner transaction
288+
...
289+
290+
raise ValueError('Abort the inner transaction')
291+
292+
# Observe the results of the outer transaction,
293+
# without effects from the inner transaction.
294+
await db.fetch_all('SELECT * FROM ...')
295+
```
296+
254297
```python
255298
async with database.transaction(isolation="serializable"):
256299
...

docs/release-notes.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# Release Notes
22

3+
## 0.5.0
4+
5+
### Fixed
6+
7+
- Patch done in the core of Databases fixing the concurrent usage of connections and transactions.
8+
This patch also affects databases. [#PR 546](https://github.com/encode/databases/pull/546) by [@zevisert](https://github.com/zevisert).
9+
We thank [@zevisert](https://github.com/zevisert) for the fix done in the original project that also affect Databasez.
10+
311
## 0.4.0
412

513
### Changed

docs_src/testclient/tests.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import pytest
88
import saffier
99
from databasez.testclient import DatabaseTestClient
10-
from saffier import fields
10+
from saffier.db.models import fields
11+
1112
from tests.settings import DATABASE_URL
1213

1314
database = DatabaseTestClient(DATABASE_URL, drop_database=True)

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ classifiers = [
3737
"Topic :: Internet :: WWW/HTTP :: HTTP Servers",
3838
"Topic :: Internet :: WWW/HTTP",
3939
]
40-
dependencies = ["nest_asyncio>=1.5.6,<2.0.0", "sqlalchemy>=2.0.16,<2.1"]
40+
dependencies = ["nest_asyncio>=1.5.6,<2.0.0", "sqlalchemy>=2.0.19,<2.1"]
4141
keywords = [
4242
"mysql",
4343
"postgres",
@@ -78,6 +78,7 @@ dev = [
7878
"aiosqlite>=0.18.0,<0.20.0",
7979
"asyncpg>=0.27.0,<0.30.0",
8080
"aioodbc>=0.4.0,<0.5.0",
81+
"ipdb>=0.13.13",
8182
"pre-commit>=2.17.0,<4.0.0",
8283
"psycopg2-binary>=2.9.6,<3.0.0",
8384
"pymysql>=1.0.3,<2.0.0",

0 commit comments

Comments
 (0)