Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 574626a

Browse files
committed
feat: use ContextVar[dict] to track connections and transactions per task
1 parent 90e5350 commit 574626a

File tree

1 file changed

+46
-36
lines changed

1 file changed

+46
-36
lines changed

databases/core.py

Lines changed: 46 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import typing
66
from contextvars import ContextVar
77
from types import TracebackType
8-
from typing import Optional
8+
from typing import Dict, Optional
99
from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit
1010

1111
from sqlalchemy import text
@@ -35,6 +35,21 @@
3535

3636
logger = logging.getLogger("databases")
3737

38+
# Connections are stored as task-local state, but care must be taken to ensure
39+
# that two database instances in the same task overwrite each other's connections.
40+
# For this reason, key comprises the database instance and the current task.
41+
_connection_contextmap: ContextVar[
42+
Dict[tuple["Database", asyncio.Task], "Connection"]
43+
] = ContextVar("databases:Connection")
44+
45+
46+
def _get_connection_contextmap() -> Dict[tuple["Database", asyncio.Task], "Connection"]:
47+
connections = _connection_contextmap.get(None)
48+
if connections is None:
49+
connections = {}
50+
_connection_contextmap.set(connections)
51+
return connections
52+
3853

3954
class Database:
4055
SUPPORTED_BACKENDS = {
@@ -64,14 +79,6 @@ def __init__(
6479
assert issubclass(backend_cls, DatabaseBackend)
6580
self._backend = backend_cls(self.url, **self.options)
6681

67-
# Connections are stored as task-local state, and cannot be garbage collected,
68-
# since the immutable global Context stores a strong reference to each ContextVar
69-
# that is created. We need these local ContextVars since two Database objects
70-
# could run in the same asyncio.Task with connections to different databases.
71-
self._connection_contextvar: ContextVar[Optional["Connection"]] = ContextVar(
72-
f"databases:Database:{id(self)}"
73-
)
74-
7582
# When `force_rollback=True` is used, we use a single global
7683
# connection, within a transaction that always rolls back.
7784
self._global_connection: typing.Optional[Connection] = None
@@ -119,7 +126,10 @@ async def disconnect(self) -> None:
119126
self._global_transaction = None
120127
self._global_connection = None
121128
else:
122-
self._connection_contextvar.set(None)
129+
task = asyncio.current_task()
130+
connections = _get_connection_contextmap()
131+
if (self, task) in connections:
132+
del connections[self, task]
123133

124134
await self._backend.disconnect()
125135
logger.info(
@@ -193,12 +203,12 @@ def connection(self) -> "Connection":
193203
if self._global_connection is not None:
194204
return self._global_connection
195205

196-
connection = self._connection_contextvar.get(None)
197-
if connection is None:
198-
connection = Connection(self._backend)
199-
self._connection_contextvar.set(connection)
206+
task = asyncio.current_task()
207+
connections = _get_connection_contextmap()
208+
if (self, task) not in connections:
209+
connections[self, task] = Connection(self._backend)
200210

201-
return connection
211+
return connections[self, task]
202212

203213
def transaction(
204214
self, *, force_rollback: bool = False, **kwargs: typing.Any
@@ -339,6 +349,19 @@ def _build_query(
339349

340350
_CallableType = typing.TypeVar("_CallableType", bound=typing.Callable)
341351

352+
_transaction_contextmap: ContextVar[
353+
Dict["Transaction", TransactionBackend]
354+
] = ContextVar("databases:Transactions")
355+
356+
357+
def _get_transaction_contextmap() -> Dict["Transaction", TransactionBackend]:
358+
transactions = _transaction_contextmap.get(None)
359+
if transactions is None:
360+
transactions = {}
361+
_transaction_contextmap.set(transactions)
362+
363+
return transactions
364+
342365

343366
class Transaction:
344367
def __init__(
@@ -351,15 +374,6 @@ def __init__(
351374
self._force_rollback = force_rollback
352375
self._extra_options = kwargs
353376

354-
# This ContextVar can never be garbage collected - similar to the ContextVar
355-
# at Database._connection_contextvar - since the current Context has a strong
356-
# reference to every ContextVar that is created. We need local ContextVars since
357-
# there may be multiple (even nested) transactions in a single asyncio.Task,
358-
# which each need their own unique TransactionBackend object.
359-
self._transaction_contextvar: ContextVar[
360-
Optional[TransactionBackend]
361-
] = ContextVar(f"databases:Transaction:{id(self)}")
362-
363377
async def __aenter__(self) -> "Transaction":
364378
"""
365379
Called when entering `async with database.transaction()`
@@ -402,12 +416,8 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
402416
async def start(self) -> "Transaction":
403417
connection = self._connection_callable()
404418
transaction = connection._connection.transaction()
405-
406-
# Cannot store returned reset token anywhere, for the same reason
407-
# we need a ContextVar in the first place - `self` is not
408-
# a safe object on which to store references for concurrent code.
409-
self._transaction_contextvar.set(transaction)
410-
419+
transactions = _get_transaction_contextmap()
420+
transactions[self] = transaction
411421
async with connection._transaction_lock:
412422
is_root = not connection._transaction_stack
413423
await connection.__aenter__()
@@ -417,27 +427,27 @@ async def start(self) -> "Transaction":
417427

418428
async def commit(self) -> None:
419429
connection = self._connection_callable()
420-
transaction = self._transaction_contextvar.get(None)
430+
transactions = _get_transaction_contextmap()
431+
transaction = transactions.get(self, None)
421432
assert transaction is not None, "Transaction not found in current task"
422433
async with connection._transaction_lock:
423434
assert connection._transaction_stack[-1] is self
424435
connection._transaction_stack.pop()
425436
await transaction.commit()
426437
await connection.__aexit__()
427-
# Have no reset token, set to None instead
428-
self._transaction_contextvar.set(None)
438+
del transactions[self]
429439

430440
async def rollback(self) -> None:
431441
connection = self._connection_callable()
432-
transaction = self._transaction_contextvar.get(None)
442+
transactions = _get_transaction_contextmap()
443+
transaction = transactions.get(self, None)
433444
assert transaction is not None, "Transaction not found in current task"
434445
async with connection._transaction_lock:
435446
assert connection._transaction_stack[-1] is self
436447
connection._transaction_stack.pop()
437448
await transaction.rollback()
438449
await connection.__aexit__()
439-
# Have no reset token, set to None instead
440-
self._transaction_contextvar.set(None)
450+
del transactions[self]
441451

442452

443453
class _EmptyNetloc(str):

0 commit comments

Comments
 (0)