Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit bea6629

Browse files
committed
refactor: rename contextvar class attributes, add some explaination comments
1 parent 90c33da commit bea6629

File tree

1 file changed

+34
-16
lines changed

1 file changed

+34
-16
lines changed

databases/core.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import typing
66
from contextvars import ContextVar
77
from types import TracebackType
8+
from typing import Optional
89
from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit
910

1011
from sqlalchemy import text
@@ -63,8 +64,13 @@ def __init__(
6364
assert issubclass(backend_cls, DatabaseBackend)
6465
self._backend = backend_cls(self.url, **self.options)
6566

66-
# Connections are stored as task-local state.
67-
self._connection_context: ContextVar = ContextVar("connection_context")
67+
# Connections are stored as task-local state, and cannot be garbage collected,
68+
# since the immutable global Context stores a strong reference to each ContextVar
69+
# that is created. We need these local ContextVars since two Database objects
70+
# could run in the same asyncio.Task with connections to different databases.
71+
self._connection_contextvar: ContextVar[Optional["Connection"]] = ContextVar(
72+
f"databases:Database:{id(self)}"
73+
)
6874

6975
# When `force_rollback=True` is used, we use a single global
7076
# connection, within a transaction that always rolls back.
@@ -113,7 +119,7 @@ async def disconnect(self) -> None:
113119
self._global_transaction = None
114120
self._global_connection = None
115121
else:
116-
self._connection_context = ContextVar("connection_context")
122+
self._connection_contextvar.set(None)
117123

118124
await self._backend.disconnect()
119125
logger.info(
@@ -187,12 +193,12 @@ def connection(self) -> "Connection":
187193
if self._global_connection is not None:
188194
return self._global_connection
189195

190-
try:
191-
return self._connection_context.get()
192-
except LookupError:
196+
connection = self._connection_contextvar.get(default=None)
197+
if connection is None:
193198
connection = Connection(self._backend)
194-
self._connection_context.set(connection)
195-
return connection
199+
self._connection_contextvar.set(connection)
200+
201+
return connection
196202

197203
def transaction(
198204
self, *, force_rollback: bool = False, **kwargs: typing.Any
@@ -344,9 +350,15 @@ def __init__(
344350
self._connection_callable = connection_callable
345351
self._force_rollback = force_rollback
346352
self._extra_options = kwargs
347-
self._transaction_context: ContextVar[TransactionBackend | None] = ContextVar(
348-
"transaction_context"
349-
)
353+
354+
# This ContextVar can never be garbage collected - similar to the ContextVar
355+
# at Database._connection_contextvar - since the current Context has a strong
356+
# reference to every ContextVar that is created. We need local ContextVars since
357+
# there may be multiple (even nested) transactions in a single asyncio.Task,
358+
# which each need their own unique TransactionBackend object.
359+
self._transaction_contextvar: ContextVar[
360+
Optional[TransactionBackend]
361+
] = ContextVar(f"databases:Transaction:{id(self)}")
350362

351363
async def __aenter__(self) -> "Transaction":
352364
"""
@@ -390,7 +402,11 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
390402
async def start(self) -> "Transaction":
391403
connection = self._connection_callable()
392404
transaction = connection._connection.transaction()
393-
self._transaction_context.set(transaction)
405+
406+
# Cannot store returned reset token anywhere, for the same reason
407+
# we need a ContextVar in the first place - `self` is not
408+
# a safe object on which to store references for concurrent code.
409+
self._transaction_contextvar.set(transaction)
394410

395411
async with connection._transaction_lock:
396412
is_root = not connection._transaction_stack
@@ -401,25 +417,27 @@ async def start(self) -> "Transaction":
401417

402418
async def commit(self) -> None:
403419
connection = self._connection_callable()
404-
transaction = self._transaction_context.get()
420+
transaction = self._transaction_contextvar.get(default=None)
405421
assert transaction is not None, "Transaction not found in current task"
406422
async with connection._transaction_lock:
407423
assert connection._transaction_stack[-1] is self
408424
connection._transaction_stack.pop()
409425
await transaction.commit()
410426
await connection.__aexit__()
411-
self._transaction_context.set(None)
427+
# Have no reset token, set to None instead
428+
self._transaction_contextvar.set(None)
412429

413430
async def rollback(self) -> None:
414431
connection = self._connection_callable()
415-
transaction = self._transaction_context.get()
432+
transaction = self._transaction_contextvar.get(default=None)
416433
assert transaction is not None, "Transaction not found in current task"
417434
async with connection._transaction_lock:
418435
assert connection._transaction_stack[-1] is self
419436
connection._transaction_stack.pop()
420437
await transaction.rollback()
421438
await connection.__aexit__()
422-
self._transaction_context.set(None)
439+
# Have no reset token, set to None instead
440+
self._transaction_contextvar.set(None)
423441

424442

425443
class _EmptyNetloc(str):

0 commit comments

Comments
 (0)