|
11 | 11 | from sqlalchemy.sql import ClauseElement
|
12 | 12 |
|
13 | 13 | from databases.importer import import_from_string
|
14 |
| -from databases.interfaces import DatabaseBackend, Record |
| 14 | +from databases.interfaces import DatabaseBackend, Record, TransactionBackend |
15 | 15 |
|
16 | 16 | try: # pragma: no cover
|
17 | 17 | import click
|
@@ -344,6 +344,9 @@ def __init__(
|
344 | 344 | self._connection_callable = connection_callable
|
345 | 345 | self._force_rollback = force_rollback
|
346 | 346 | self._extra_options = kwargs
|
| 347 | + self._transaction_context: ContextVar[TransactionBackend | None] = ContextVar( |
| 348 | + "transaction_context" |
| 349 | + ) |
347 | 350 |
|
348 | 351 | async def __aenter__(self) -> "Transaction":
|
349 | 352 | """
|
@@ -385,31 +388,38 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
|
385 | 388 | return wrapper # type: ignore
|
386 | 389 |
|
387 | 390 | async def start(self) -> "Transaction":
|
388 |
| - self._connection = self._connection_callable() |
389 |
| - self._transaction = self._connection._connection.transaction() |
390 |
| - |
391 |
| - async with self._connection._transaction_lock: |
392 |
| - is_root = not self._connection._transaction_stack |
393 |
| - await self._connection.__aenter__() |
394 |
| - await self._transaction.start( |
395 |
| - is_root=is_root, extra_options=self._extra_options |
396 |
| - ) |
397 |
| - self._connection._transaction_stack.append(self) |
| 391 | + connection = self._connection_callable() |
| 392 | + transaction = connection._connection.transaction() |
| 393 | + self._transaction_context.set(transaction) |
| 394 | + |
| 395 | + async with connection._transaction_lock: |
| 396 | + is_root = not connection._transaction_stack |
| 397 | + await connection.__aenter__() |
| 398 | + await transaction.start(is_root=is_root, extra_options=self._extra_options) |
| 399 | + connection._transaction_stack.append(self) |
398 | 400 | return self
|
399 | 401 |
|
400 | 402 | async def commit(self) -> None:
|
401 |
| - async with self._connection._transaction_lock: |
402 |
| - assert self._connection._transaction_stack[-1] is self |
403 |
| - self._connection._transaction_stack.pop() |
404 |
| - await self._transaction.commit() |
405 |
| - await self._connection.__aexit__() |
| 403 | + connection = self._connection_callable() |
| 404 | + transaction = self._transaction_context.get() |
| 405 | + assert transaction is not None, "Transaction not found in current task" |
| 406 | + async with connection._transaction_lock: |
| 407 | + assert connection._transaction_stack[-1] is self |
| 408 | + connection._transaction_stack.pop() |
| 409 | + await transaction.commit() |
| 410 | + await connection.__aexit__() |
| 411 | + self._transaction_context.set(None) |
406 | 412 |
|
407 | 413 | async def rollback(self) -> None:
|
408 |
| - async with self._connection._transaction_lock: |
409 |
| - assert self._connection._transaction_stack[-1] is self |
410 |
| - self._connection._transaction_stack.pop() |
411 |
| - await self._transaction.rollback() |
412 |
| - await self._connection.__aexit__() |
| 414 | + connection = self._connection_callable() |
| 415 | + transaction = self._transaction_context.get() |
| 416 | + assert transaction is not None, "Transaction not found in current task" |
| 417 | + async with connection._transaction_lock: |
| 418 | + assert connection._transaction_stack[-1] is self |
| 419 | + connection._transaction_stack.pop() |
| 420 | + await transaction.rollback() |
| 421 | + await connection.__aexit__() |
| 422 | + self._transaction_context.set(None) |
413 | 423 |
|
414 | 424 |
|
415 | 425 | class _EmptyNetloc(str):
|
|
0 commit comments