5
5
import typing
6
6
from contextvars import ContextVar
7
7
from types import TracebackType
8
- from typing import Optional
8
+ from typing import Dict , Optional
9
9
from urllib .parse import SplitResult , parse_qsl , unquote , urlsplit
10
10
11
11
from sqlalchemy import text
35
35
36
36
logger = logging .getLogger ("databases" )
37
37
38
+ # Connections are stored as task-local state, but care must be taken to ensure
39
+ # that two database instances in the same task overwrite each other's connections.
40
+ # For this reason, key comprises the database instance and the current task.
41
+ _connection_contextmap : ContextVar [
42
+ Dict [tuple ["Database" , asyncio .Task ], "Connection" ]
43
+ ] = ContextVar ("databases:Connection" )
44
+
45
+
46
+ def _get_connection_contextmap () -> Dict [tuple ["Database" , asyncio .Task ], "Connection" ]:
47
+ connections = _connection_contextmap .get (None )
48
+ if connections is None :
49
+ connections = {}
50
+ _connection_contextmap .set (connections )
51
+ return connections
52
+
38
53
39
54
class Database :
40
55
SUPPORTED_BACKENDS = {
@@ -64,14 +79,6 @@ def __init__(
64
79
assert issubclass (backend_cls , DatabaseBackend )
65
80
self ._backend = backend_cls (self .url , ** self .options )
66
81
67
- # Connections are stored as task-local state, and cannot be garbage collected,
68
- # since the immutable global Context stores a strong reference to each ContextVar
69
- # that is created. We need these local ContextVars since two Database objects
70
- # could run in the same asyncio.Task with connections to different databases.
71
- self ._connection_contextvar : ContextVar [Optional ["Connection" ]] = ContextVar (
72
- f"databases:Database:{ id (self )} "
73
- )
74
-
75
82
# When `force_rollback=True` is used, we use a single global
76
83
# connection, within a transaction that always rolls back.
77
84
self ._global_connection : typing .Optional [Connection ] = None
@@ -119,7 +126,10 @@ async def disconnect(self) -> None:
119
126
self ._global_transaction = None
120
127
self ._global_connection = None
121
128
else :
122
- self ._connection_contextvar .set (None )
129
+ task = asyncio .current_task ()
130
+ connections = _get_connection_contextmap ()
131
+ if (self , task ) in connections :
132
+ del connections [self , task ]
123
133
124
134
await self ._backend .disconnect ()
125
135
logger .info (
@@ -193,12 +203,12 @@ def connection(self) -> "Connection":
193
203
if self ._global_connection is not None :
194
204
return self ._global_connection
195
205
196
- connection = self . _connection_contextvar . get ( None )
197
- if connection is None :
198
- connection = Connection ( self . _backend )
199
- self . _connection_contextvar . set ( connection )
206
+ task = asyncio . current_task ( )
207
+ connections = _get_connection_contextmap ()
208
+ if ( self , task ) not in connections :
209
+ connections [ self , task ] = Connection ( self . _backend )
200
210
201
- return connection
211
+ return connections [ self , task ]
202
212
203
213
def transaction (
204
214
self , * , force_rollback : bool = False , ** kwargs : typing .Any
@@ -339,6 +349,19 @@ def _build_query(
339
349
340
350
_CallableType = typing .TypeVar ("_CallableType" , bound = typing .Callable )
341
351
352
+ _transaction_contextmap : ContextVar [
353
+ Dict ["Transaction" , TransactionBackend ]
354
+ ] = ContextVar ("databases:Transactions" )
355
+
356
+
357
+ def _get_transaction_contextmap () -> Dict ["Transaction" , TransactionBackend ]:
358
+ transactions = _transaction_contextmap .get (None )
359
+ if transactions is None :
360
+ transactions = {}
361
+ _transaction_contextmap .set (transactions )
362
+
363
+ return transactions
364
+
342
365
343
366
class Transaction :
344
367
def __init__ (
@@ -351,15 +374,6 @@ def __init__(
351
374
self ._force_rollback = force_rollback
352
375
self ._extra_options = kwargs
353
376
354
- # This ContextVar can never be garbage collected - similar to the ContextVar
355
- # at Database._connection_contextvar - since the current Context has a strong
356
- # reference to every ContextVar that is created. We need local ContextVars since
357
- # there may be multiple (even nested) transactions in a single asyncio.Task,
358
- # which each need their own unique TransactionBackend object.
359
- self ._transaction_contextvar : ContextVar [
360
- Optional [TransactionBackend ]
361
- ] = ContextVar (f"databases:Transaction:{ id (self )} " )
362
-
363
377
async def __aenter__ (self ) -> "Transaction" :
364
378
"""
365
379
Called when entering `async with database.transaction()`
@@ -402,12 +416,8 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
402
416
async def start (self ) -> "Transaction" :
403
417
connection = self ._connection_callable ()
404
418
transaction = connection ._connection .transaction ()
405
-
406
- # Cannot store returned reset token anywhere, for the same reason
407
- # we need a ContextVar in the first place - `self` is not
408
- # a safe object on which to store references for concurrent code.
409
- self ._transaction_contextvar .set (transaction )
410
-
419
+ transactions = _get_transaction_contextmap ()
420
+ transactions [self ] = transaction
411
421
async with connection ._transaction_lock :
412
422
is_root = not connection ._transaction_stack
413
423
await connection .__aenter__ ()
@@ -417,27 +427,27 @@ async def start(self) -> "Transaction":
417
427
418
428
async def commit (self ) -> None :
419
429
connection = self ._connection_callable ()
420
- transaction = self ._transaction_contextvar .get (None )
430
+ transactions = _get_transaction_contextmap ()
431
+ transaction = transactions .get (self , None )
421
432
assert transaction is not None , "Transaction not found in current task"
422
433
async with connection ._transaction_lock :
423
434
assert connection ._transaction_stack [- 1 ] is self
424
435
connection ._transaction_stack .pop ()
425
436
await transaction .commit ()
426
437
await connection .__aexit__ ()
427
- # Have no reset token, set to None instead
428
- self ._transaction_contextvar .set (None )
438
+ del transactions [self ]
429
439
430
440
async def rollback (self ) -> None :
431
441
connection = self ._connection_callable ()
432
- transaction = self ._transaction_contextvar .get (None )
442
+ transactions = _get_transaction_contextmap ()
443
+ transaction = transactions .get (self , None )
433
444
assert transaction is not None , "Transaction not found in current task"
434
445
async with connection ._transaction_lock :
435
446
assert connection ._transaction_stack [- 1 ] is self
436
447
connection ._transaction_stack .pop ()
437
448
await transaction .rollback ()
438
449
await connection .__aexit__ ()
439
- # Have no reset token, set to None instead
440
- self ._transaction_contextvar .set (None )
450
+ del transactions [self ]
441
451
442
452
443
453
class _EmptyNetloc (str ):
0 commit comments