Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion databasez/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from databasez.core import Database, DatabaseURL

__version__ = "0.4.0"
__version__ = "0.5.0"

__all__ = ["Database", "DatabaseURL"]
96 changes: 76 additions & 20 deletions databasez/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
import functools
import logging
import typing
import weakref
from contextvars import ContextVar
from types import TracebackType
from urllib.parse import SplitResult, parse_qsl, unquote, urlencode, urlsplit

from sqlalchemy import text
from sqlalchemy.sql import ClauseElement
from sqlalchemy.util._concurrency_py3k import greenlet_spawn

from databasez.importer import import_from_string
from databasez.interfaces import DatabaseBackend, Record
from databasez.interfaces import DatabaseBackend, Record, TransactionBackend

if typing.TYPE_CHECKING:
from databasez.types import DictAny
Expand All @@ -35,6 +35,11 @@
logger = logging.getLogger("databasez")


ACTIVE_TRANSACTIONS: ContextVar[
typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"]
] = ContextVar("databasez:active_transactions", default=None)


class Database:
"""
An abstraction on the top of the EncodeORM databases.Database object.
Expand Down Expand Up @@ -72,6 +77,7 @@ class Database:
}
DIRECT_URL_SCHEME = {"sqlite"}
MANDATORY_FIELDS = ["host", "port", "user", "database"]
_connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']"

def __init__(
self,
Expand All @@ -92,6 +98,7 @@ def __init__(
self.url = DatabaseURL(_url) # type: ignore
self.options = options
self.is_connected = False
self._connection_map = weakref.WeakKeyDictionary()

self._force_rollback = force_rollback

Expand All @@ -100,9 +107,6 @@ def __init__(
assert issubclass(backend_cls, DatabaseBackend)
self._backend = backend_cls(self.url, **self.options)

# Connections are stored as task-local state.
self._connection_context: ContextVar = ContextVar("connection_context")

# When `force_rollback=True` is used, we use a single global
# connection, within a transaction that always rolls back.
self._global_connection: typing.Optional[Connection] = None
Expand Down Expand Up @@ -164,6 +168,30 @@ def _build_url_for_direct_url_scheme(self, scheme: str, database: str) -> str:
"""
return f"{scheme}:///{database}"

@property
def _current_task(self) -> asyncio.Task:
task = asyncio.current_task()
if not task:
raise RuntimeError("No currently active asyncio.Task found")
return task

@property
def _connection(self) -> typing.Optional["Connection"]:
return self._connection_map.get(self._current_task)

@_connection.setter
def _connection(
self, connection: typing.Optional["Connection"]
) -> typing.Optional["Connection"]:
task = self._current_task

if connection is None:
self._connection_map.pop(task, None)
else:
self._connection_map[task] = connection

return self._connection

async def connect(self) -> None:
"""
Establish the connection pool.
Expand All @@ -180,7 +208,7 @@ async def connect(self) -> None:
assert self._global_connection is None
assert self._global_transaction is None

self._global_connection = Connection(self._backend)
self._global_connection = Connection(self, self._backend)
self._global_transaction = self._global_connection.transaction(force_rollback=True)

await self._global_transaction.__aenter__()
Expand All @@ -202,7 +230,7 @@ async def disconnect(self) -> None:
self._global_transaction = None
self._global_connection = None
else:
self._connection_context = ContextVar("connection_context")
self._connection = None

await self._backend.disconnect()
logger.info(
Expand Down Expand Up @@ -274,12 +302,9 @@ def connection(self) -> "Connection":
if self._global_connection is not None:
return self._global_connection

try:
return self._connection_context.get() # type: ignore
except LookupError:
connection = Connection(self._backend)
self._connection_context.set(connection)
return connection
if not self._connection:
self._connection = Connection(self, self._backend)
return self._connection

def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction":
return Transaction(self.connection, force_rollback=force_rollback, **kwargs)
Expand All @@ -300,7 +325,8 @@ def _get_backend(self) -> str:


class Connection:
def __init__(self, backend: DatabaseBackend) -> None:
def __init__(self, database: Database, backend: DatabaseBackend) -> None:
self._database = database
self._backend = backend

self._connection_lock = asyncio.Lock()
Expand Down Expand Up @@ -334,6 +360,7 @@ async def __aexit__(
self._connection_counter -= 1
if self._connection_counter == 0:
await self._connection.release()
self._database._connection = None

async def fetch_all(
self,
Expand Down Expand Up @@ -398,11 +425,6 @@ def connection_callable() -> Connection:
def raw_connection(self) -> typing.Any:
return self._connection.raw_connection

async def run_sync(
self, fn: typing.Callable[..., typing.Any], *arg: typing.Any, **kw: typing.Any
) -> typing.Any:
return await greenlet_spawn(fn, self._connection.raw_connection, *arg, **kw)

@staticmethod
def _build_query(
query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None
Expand Down Expand Up @@ -431,6 +453,37 @@ def __init__(
self._force_rollback = force_rollback
self._extra_options = kwargs

@property
def _connection(self) -> "Connection":
# Returns the same connection if called multiple times
return self._connection_callable()

@property
def _transaction(self) -> typing.Optional["TransactionBackend"]:
transactions = ACTIVE_TRANSACTIONS.get()
if transactions is None:
return None

return transactions.get(self, None)

@_transaction.setter
def _transaction(
self, transaction: typing.Optional["TransactionBackend"]
) -> typing.Optional["TransactionBackend"]:
transactions = ACTIVE_TRANSACTIONS.get()
if transactions is None:
transactions = weakref.WeakKeyDictionary()
else:
transactions = transactions.copy()

if transaction is None:
transactions.pop(self, None)
else:
transactions[self] = transaction

ACTIVE_TRANSACTIONS.set(transactions)
return transactions.get(self, None)

async def __aenter__(self) -> "Transaction":
"""
Called when entering `async with database.transaction()`
Expand Down Expand Up @@ -471,7 +524,6 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
return wrapper # type: ignore

async def start(self) -> "Transaction":
self._connection = self._connection_callable()
self._transaction = self._connection._connection.transaction()

async with self._connection._transaction_lock:
Expand All @@ -485,15 +537,19 @@ async def commit(self) -> None:
async with self._connection._transaction_lock:
assert self._connection._transaction_stack[-1] is self
self._connection._transaction_stack.pop()
assert self._transaction is not None
await self._transaction.commit()
await self._connection.__aexit__()
self._transaction = None

async def rollback(self) -> None:
async with self._connection._transaction_lock:
assert self._connection._transaction_stack[-1] is self
self._connection._transaction_stack.pop()
assert self._transaction is not None
await self._transaction.rollback()
await self._connection.__aexit__()
self._transaction = None


class _EmptyNetloc(str):
Expand Down
51 changes: 47 additions & 4 deletions docs/connections-and-transactions.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ to connect to the database.

## Connecting and disconnecting

You can control the database connect/disconnect, by using it as a async context manager.
You can control the database connection, by using it as a async context manager.

```python
async with Database(DATABASE_URL) as database:
...
```

Or by using explicit connection and disconnection:
Or by using explicit `.connect()` and `disconnect()`:

```python
database = Database(DATABASE_URL)
Expand Down Expand Up @@ -246,11 +246,54 @@ async def create_users(request):
...
```

Transaction blocks are managed as task-local state. Nested transactions
are fully supported, and are implemented using database savepoints.
The state of a transaction is liked to the connection used in the currently executing async task.
If you would like to influence an active transaction from another task, the connection must be
shared:

Transaction isolation-level can be specified if the driver backend supports that:

```python
async def add_excitement(connnection: databases.core.Connection, id: int):
await connection.execute(
"UPDATE notes SET text = CONCAT(text, '!!!') WHERE id = :id",
{"id": id}
)


async with Database(database_url) as database:
async with database.transaction():
# This note won't exist until the transaction closes...
await database.execute(
"INSERT INTO notes(id, text) values (1, 'databases is cool')"
)
# ...but child tasks can use this connection now!
await asyncio.create_task(add_excitement(database.connection(), id=1))

await database.fetch_val("SELECT text FROM notes WHERE id=1")
# ^ returns: "databases is cool!!!"
```

Nested transactions are fully supported, and are implemented using database savepoints:

```python
async with databases.Database(database_url) as db:
async with db.transaction() as outer:
# Do something in the outer transaction
...

# Suppress to prevent influence on the outer transaction
with contextlib.suppress(ValueError):
async with db.transaction():
# Do something in the inner transaction
...

raise ValueError('Abort the inner transaction')

# Observe the results of the outer transaction,
# without effects from the inner transaction.
await db.fetch_all('SELECT * FROM ...')
```

```python
async with database.transaction(isolation="serializable"):
...
Expand Down
8 changes: 8 additions & 0 deletions docs/release-notes.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Release Notes

## 0.5.0

### Fixed

- Patch done in the core of Databases fixing the concurrent usage of connections and transactions.
This patch also affects databases. [#PR 546](https://github.com/encode/databases/pull/546) by [@zevisert](https://github.com/zevisert).
We thank [@zevisert](https://github.com/zevisert) for the fix done in the original project that also affect Databasez.

## 0.4.0

### Changed
Expand Down
3 changes: 2 additions & 1 deletion docs_src/testclient/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import pytest
import saffier
from databasez.testclient import DatabaseTestClient
from saffier import fields
from saffier.db.models import fields

from tests.settings import DATABASE_URL

database = DatabaseTestClient(DATABASE_URL, drop_database=True)
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ classifiers = [
"Topic :: Internet :: WWW/HTTP :: HTTP Servers",
"Topic :: Internet :: WWW/HTTP",
]
dependencies = ["nest_asyncio>=1.5.6,<2.0.0", "sqlalchemy>=2.0.16,<2.1"]
dependencies = ["nest_asyncio>=1.5.6,<2.0.0", "sqlalchemy>=2.0.19,<2.1"]
keywords = [
"mysql",
"postgres",
Expand Down Expand Up @@ -78,6 +78,7 @@ dev = [
"aiosqlite>=0.18.0,<0.20.0",
"asyncpg>=0.27.0,<0.30.0",
"aioodbc>=0.4.0,<0.5.0",
"ipdb>=0.13.13",
"pre-commit>=2.17.0,<4.0.0",
"psycopg2-binary>=2.9.6,<3.0.0",
"pymysql>=1.0.3,<2.0.0",
Expand Down
Loading