5
5
import typing
6
6
from contextvars import ContextVar
7
7
from types import TracebackType
8
+ from typing import Optional
8
9
from urllib .parse import SplitResult , parse_qsl , unquote , urlsplit
9
10
10
11
from sqlalchemy import text
@@ -63,8 +64,13 @@ def __init__(
63
64
assert issubclass (backend_cls , DatabaseBackend )
64
65
self ._backend = backend_cls (self .url , ** self .options )
65
66
66
- # Connections are stored as task-local state.
67
- self ._connection_context : ContextVar = ContextVar ("connection_context" )
67
+ # Connections are stored as task-local state, and cannot be garbage collected,
68
+ # since the immutable global Context stores a strong reference to each ContextVar
69
+ # that is created. We need these local ContextVars since two Database objects
70
+ # could run in the same asyncio.Task with connections to different databases.
71
+ self ._connection_contextvar : ContextVar [Optional ["Connection" ]] = ContextVar (
72
+ f"databases:Database:{ id (self )} "
73
+ )
68
74
69
75
# When `force_rollback=True` is used, we use a single global
70
76
# connection, within a transaction that always rolls back.
@@ -113,7 +119,7 @@ async def disconnect(self) -> None:
113
119
self ._global_transaction = None
114
120
self ._global_connection = None
115
121
else :
116
- self ._connection_context = ContextVar ( "connection_context" )
122
+ self ._connection_contextvar . set ( None )
117
123
118
124
await self ._backend .disconnect ()
119
125
logger .info (
@@ -187,12 +193,12 @@ def connection(self) -> "Connection":
187
193
if self ._global_connection is not None :
188
194
return self ._global_connection
189
195
190
- try :
191
- return self ._connection_context .get ()
192
- except LookupError :
196
+ connection = self ._connection_contextvar .get (default = None )
197
+ if connection is None :
193
198
connection = Connection (self ._backend )
194
- self ._connection_context .set (connection )
195
- return connection
199
+ self ._connection_contextvar .set (connection )
200
+
201
+ return connection
196
202
197
203
def transaction (
198
204
self , * , force_rollback : bool = False , ** kwargs : typing .Any
@@ -344,9 +350,15 @@ def __init__(
344
350
self ._connection_callable = connection_callable
345
351
self ._force_rollback = force_rollback
346
352
self ._extra_options = kwargs
347
- self ._transaction_context : ContextVar [TransactionBackend | None ] = ContextVar (
348
- "transaction_context"
349
- )
353
+
354
+ # This ContextVar can never be garbage collected - similar to the ContextVar
355
+ # at Database._connection_contextvar - since the current Context has a strong
356
+ # reference to every ContextVar that is created. We need local ContextVars since
357
+ # there may be multiple (even nested) transactions in a single asyncio.Task,
358
+ # which each need their own unique TransactionBackend object.
359
+ self ._transaction_contextvar : ContextVar [
360
+ Optional [TransactionBackend ]
361
+ ] = ContextVar (f"databases:Transaction:{ id (self )} " )
350
362
351
363
async def __aenter__ (self ) -> "Transaction" :
352
364
"""
@@ -390,7 +402,11 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
390
402
async def start (self ) -> "Transaction" :
391
403
connection = self ._connection_callable ()
392
404
transaction = connection ._connection .transaction ()
393
- self ._transaction_context .set (transaction )
405
+
406
+ # Cannot store returned reset token anywhere, for the same reason
407
+ # we need a ContextVar in the first place - `self` is not
408
+ # a safe object on which to store references for concurrent code.
409
+ self ._transaction_contextvar .set (transaction )
394
410
395
411
async with connection ._transaction_lock :
396
412
is_root = not connection ._transaction_stack
@@ -401,25 +417,27 @@ async def start(self) -> "Transaction":
401
417
402
418
async def commit (self ) -> None :
403
419
connection = self ._connection_callable ()
404
- transaction = self ._transaction_context .get ()
420
+ transaction = self ._transaction_contextvar .get (default = None )
405
421
assert transaction is not None , "Transaction not found in current task"
406
422
async with connection ._transaction_lock :
407
423
assert connection ._transaction_stack [- 1 ] is self
408
424
connection ._transaction_stack .pop ()
409
425
await transaction .commit ()
410
426
await connection .__aexit__ ()
411
- self ._transaction_context .set (None )
427
+ # Have no reset token, set to None instead
428
+ self ._transaction_contextvar .set (None )
412
429
413
430
async def rollback (self ) -> None :
414
431
connection = self ._connection_callable ()
415
- transaction = self ._transaction_context .get ()
432
+ transaction = self ._transaction_contextvar .get (default = None )
416
433
assert transaction is not None , "Transaction not found in current task"
417
434
async with connection ._transaction_lock :
418
435
assert connection ._transaction_stack [- 1 ] is self
419
436
connection ._transaction_stack .pop ()
420
437
await transaction .rollback ()
421
438
await connection .__aexit__ ()
422
- self ._transaction_context .set (None )
439
+ # Have no reset token, set to None instead
440
+ self ._transaction_contextvar .set (None )
423
441
424
442
425
443
class _EmptyNetloc (str ):
0 commit comments