33import functools
44import logging
55import typing
6+ import weakref
67from contextvars import ContextVar
78from types import TracebackType
89from urllib .parse import SplitResult , parse_qsl , unquote , urlencode , urlsplit
910
1011from sqlalchemy import text
1112from sqlalchemy .sql import ClauseElement
12- from sqlalchemy .util ._concurrency_py3k import greenlet_spawn
1313
1414from databasez .importer import import_from_string
15- from databasez .interfaces import DatabaseBackend , Record
15+ from databasez .interfaces import DatabaseBackend , Record , TransactionBackend
1616
1717if typing .TYPE_CHECKING :
1818 from databasez .types import DictAny
3535logger = logging .getLogger ("databasez" )
3636
3737
38+ ACTIVE_TRANSACTIONS : ContextVar [
39+ typing .Optional ["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']" ]
40+ ] = ContextVar ("databasez:active_transactions" , default = None )
41+
42+
3843class Database :
3944 """
4045 An abstraction on the top of the EncodeORM databases.Database object.
@@ -72,6 +77,7 @@ class Database:
7277 }
7378 DIRECT_URL_SCHEME = {"sqlite" }
7479 MANDATORY_FIELDS = ["host" , "port" , "user" , "database" ]
80+ _connection_map : "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']"
7581
7682 def __init__ (
7783 self ,
@@ -92,6 +98,7 @@ def __init__(
9298 self .url = DatabaseURL (_url ) # type: ignore
9399 self .options = options
94100 self .is_connected = False
101+ self ._connection_map = weakref .WeakKeyDictionary ()
95102
96103 self ._force_rollback = force_rollback
97104
@@ -100,9 +107,6 @@ def __init__(
100107 assert issubclass (backend_cls , DatabaseBackend )
101108 self ._backend = backend_cls (self .url , ** self .options )
102109
103- # Connections are stored as task-local state.
104- self ._connection_context : ContextVar = ContextVar ("connection_context" )
105-
106110 # When `force_rollback=True` is used, we use a single global
107111 # connection, within a transaction that always rolls back.
108112 self ._global_connection : typing .Optional [Connection ] = None
@@ -164,6 +168,30 @@ def _build_url_for_direct_url_scheme(self, scheme: str, database: str) -> str:
164168 """
165169 return f"{ scheme } :///{ database } "
166170
171+ @property
172+ def _current_task (self ) -> asyncio .Task :
173+ task = asyncio .current_task ()
174+ if not task :
175+ raise RuntimeError ("No currently active asyncio.Task found" )
176+ return task
177+
178+ @property
179+ def _connection (self ) -> typing .Optional ["Connection" ]:
180+ return self ._connection_map .get (self ._current_task )
181+
182+ @_connection .setter
183+ def _connection (
184+ self , connection : typing .Optional ["Connection" ]
185+ ) -> typing .Optional ["Connection" ]:
186+ task = self ._current_task
187+
188+ if connection is None :
189+ self ._connection_map .pop (task , None )
190+ else :
191+ self ._connection_map [task ] = connection
192+
193+ return self ._connection
194+
167195 async def connect (self ) -> None :
168196 """
169197 Establish the connection pool.
@@ -180,7 +208,7 @@ async def connect(self) -> None:
180208 assert self ._global_connection is None
181209 assert self ._global_transaction is None
182210
183- self ._global_connection = Connection (self ._backend )
211+ self ._global_connection = Connection (self , self ._backend )
184212 self ._global_transaction = self ._global_connection .transaction (force_rollback = True )
185213
186214 await self ._global_transaction .__aenter__ ()
@@ -202,7 +230,7 @@ async def disconnect(self) -> None:
202230 self ._global_transaction = None
203231 self ._global_connection = None
204232 else :
205- self ._connection_context = ContextVar ( "connection_context" )
233+ self ._connection = None
206234
207235 await self ._backend .disconnect ()
208236 logger .info (
@@ -274,12 +302,9 @@ def connection(self) -> "Connection":
274302 if self ._global_connection is not None :
275303 return self ._global_connection
276304
277- try :
278- return self ._connection_context .get () # type: ignore
279- except LookupError :
280- connection = Connection (self ._backend )
281- self ._connection_context .set (connection )
282- return connection
305+ if not self ._connection :
306+ self ._connection = Connection (self , self ._backend )
307+ return self ._connection
283308
284309 def transaction (self , * , force_rollback : bool = False , ** kwargs : typing .Any ) -> "Transaction" :
285310 return Transaction (self .connection , force_rollback = force_rollback , ** kwargs )
@@ -300,7 +325,8 @@ def _get_backend(self) -> str:
300325
301326
302327class Connection :
303- def __init__ (self , backend : DatabaseBackend ) -> None :
328+ def __init__ (self , database : Database , backend : DatabaseBackend ) -> None :
329+ self ._database = database
304330 self ._backend = backend
305331
306332 self ._connection_lock = asyncio .Lock ()
@@ -334,6 +360,7 @@ async def __aexit__(
334360 self ._connection_counter -= 1
335361 if self ._connection_counter == 0 :
336362 await self ._connection .release ()
363+ self ._database ._connection = None
337364
338365 async def fetch_all (
339366 self ,
@@ -398,11 +425,6 @@ def connection_callable() -> Connection:
398425 def raw_connection (self ) -> typing .Any :
399426 return self ._connection .raw_connection
400427
401- async def run_sync (
402- self , fn : typing .Callable [..., typing .Any ], * arg : typing .Any , ** kw : typing .Any
403- ) -> typing .Any :
404- return await greenlet_spawn (fn , self ._connection .raw_connection , * arg , ** kw )
405-
406428 @staticmethod
407429 def _build_query (
408430 query : typing .Union [ClauseElement , str ], values : typing .Optional [dict ] = None
@@ -431,6 +453,37 @@ def __init__(
431453 self ._force_rollback = force_rollback
432454 self ._extra_options = kwargs
433455
456+ @property
457+ def _connection (self ) -> "Connection" :
458+ # Returns the same connection if called multiple times
459+ return self ._connection_callable ()
460+
461+ @property
462+ def _transaction (self ) -> typing .Optional ["TransactionBackend" ]:
463+ transactions = ACTIVE_TRANSACTIONS .get ()
464+ if transactions is None :
465+ return None
466+
467+ return transactions .get (self , None )
468+
469+ @_transaction .setter
470+ def _transaction (
471+ self , transaction : typing .Optional ["TransactionBackend" ]
472+ ) -> typing .Optional ["TransactionBackend" ]:
473+ transactions = ACTIVE_TRANSACTIONS .get ()
474+ if transactions is None :
475+ transactions = weakref .WeakKeyDictionary ()
476+ else :
477+ transactions = transactions .copy ()
478+
479+ if transaction is None :
480+ transactions .pop (self , None )
481+ else :
482+ transactions [self ] = transaction
483+
484+ ACTIVE_TRANSACTIONS .set (transactions )
485+ return transactions .get (self , None )
486+
434487 async def __aenter__ (self ) -> "Transaction" :
435488 """
436489 Called when entering `async with database.transaction()`
@@ -471,7 +524,6 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
471524 return wrapper # type: ignore
472525
473526 async def start (self ) -> "Transaction" :
474- self ._connection = self ._connection_callable ()
475527 self ._transaction = self ._connection ._connection .transaction ()
476528
477529 async with self ._connection ._transaction_lock :
@@ -485,15 +537,19 @@ async def commit(self) -> None:
485537 async with self ._connection ._transaction_lock :
486538 assert self ._connection ._transaction_stack [- 1 ] is self
487539 self ._connection ._transaction_stack .pop ()
540+ assert self ._transaction is not None
488541 await self ._transaction .commit ()
489542 await self ._connection .__aexit__ ()
543+ self ._transaction = None
490544
491545 async def rollback (self ) -> None :
492546 async with self ._connection ._transaction_lock :
493547 assert self ._connection ._transaction_stack [- 1 ] is self
494548 self ._connection ._transaction_stack .pop ()
549+ assert self ._transaction is not None
495550 await self ._transaction .rollback ()
496551 await self ._connection .__aexit__ ()
552+ self ._transaction = None
497553
498554
499555class _EmptyNetloc (str ):
0 commit comments