Skip to content

Add support for db fixture (test inside transaction) for asyncio tests #1223

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions docs/database.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,88 @@ select using an argument to the ``django_db`` mark::
def test_spam():
pass # test relying on transactions


Async tests and database transactions
-------------------------------------

``pytest-django`` supports async tests that use Django's async ORM APIs.
This requires the `pytest-asyncio <https://github.com/pytest-dev/pytest-asyncio>`_
plugin and marking your tests appropriately.

Requirements
------------

- Install ``pytest-asyncio``.
- Mark async tests with both ``@pytest.mark.asyncio`` and
``@pytest.mark.django_db`` (or request the ``db``/``transactional_db`` fixtures).

Example (async ORM with transactional rollback per test)::

import pytest

@pytest.mark.asyncio
@pytest.mark.django_db
async def test_async_db_is_isolated():
assert await Item.objects.acount() == 0
await Item.objects.acreate(name="example")
assert await Item.objects.acount() == 1
# changes are rolled back after the test

.. _`async-db-behavior`:

Behavior of ``db`` in async tests
---------------------------------

Tests using ``db`` wrap each test in a transaction and roll that transaction back at the end
(like ``django.test.TestCase``). In Django, transactions are bound to the database
connection, which is unique per thread. This means that all your database changes
must be made within the same thread to ensure they are rolled back before the next test.

Django Async ORM calls, as of writing, use the ``asgiref.sync.sync_to_async``
decorator to run the ORM calls on a dedicated thread executor.

For async tests, pytest-django ensures the transaction
setup/teardown happens via ``asgiref.sync.sync_to_async``, which means the transaction is started & run on the
same thread on which async orm calls inside your test, like ``aget()`` are made. This ensures your test code
can safely modify the database using the async calls, as all its queries will be rolled back after the test.

Tests using ``transactional_db`` flush the database between tests. This means that no matter in which thread
your test modifies the database, the changes will be removed after the test. This means you can avoid thinking
about sync/async database access if your test uses ``transactional_db``, at the cost of slower tests:
A flush is generally slower than rolling back a transaction.

.. _`db-thread-safeguards`:

Safeguards against database access from different threads
---------------------------------------------------------
When using the database in a test with transaction rollback, you must ensure that
database access is only done from the same thread that the test is running on.

To avoid your fixtures/tests making changes outside the test thread, and as a result, the transaction, pytest-django
actively restricts where database connections may be opened:

- In async tests using ``db``: database access is only allowed from the single
thread used by ``SyncToAsync``. Using sync fixtures that touch the database in
an async test will raise::

RuntimeError: Database access is only allowed in an async context, modify your
test fixtures to be async or use the transactional_db fixture.

Fix by converting those fixtures to async (use ``pytest_asyncio.fixture``) and
using Django's async ORM methods (e.g. ``.acreate()``, ``.aget()``, ``.acount()``),
or by requesting ``transactional_db`` if you must keep sync fixtures.
See :ref:`async-db-behavior` for more details.

- In sync tests: database access is only allowed from the main thread. Attempting to use the database connection
from a different thread will raise::

RuntimeError: Database access is only allowed in the main thread, modify your
test fixtures to be sync or use the transactional_db fixture.

Fix this by ensuring all database transactions run in the main thread (e.g., avoiding the use of async fixtures),
or use ``transactional_db`` to allow mixing.


.. _`multi-db`:

Tests requiring multiple databases
Expand Down Expand Up @@ -524,3 +606,4 @@ Put this in ``conftest.py``::
django_db_blocker.unblock()
yield
django_db_blocker.restore()

4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ coverage = [
"coverage[toml]",
"coverage-enable-subprocess",
]
async = [
"asgiref>=3.9.1",
"pytest-asyncio",
]
postgres = [
"psycopg[binary]",
]
Expand Down
196 changes: 158 additions & 38 deletions pytest_django/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import os
from collections.abc import Generator, Iterable, Sequence
from collections.abc import AsyncGenerator, Generator, Iterable, Sequence
from contextlib import AbstractContextManager, contextmanager
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Protocol, Union
Expand Down Expand Up @@ -201,8 +201,49 @@ def django_db_setup(
)


def _build_pytest_django_test_case(
test_case_class,
*,
reset_sequences: bool,
serialized_rollback: bool,
databases,
available_apps,
skip_django_testcase_class_setup: bool,
):
# Build a custom TestCase subclass with configured attributes and optional
# overrides to skip Django's TestCase class-level setup/teardown.
import django.test # local import to avoid hard dependency at import time

_reset_sequences = reset_sequences
_serialized_rollback = serialized_rollback
_databases = databases
_available_apps = available_apps

class PytestDjangoTestCase(test_case_class):
reset_sequences = _reset_sequences
serialized_rollback = _serialized_rollback
if _databases is not None:
databases = _databases
if _available_apps is not None:
available_apps = _available_apps

if skip_django_testcase_class_setup:

@classmethod
def setUpClass(cls) -> None:
# Skip django.test.TestCase.setUpClass, call its super instead
super(django.test.TestCase, cls).setUpClass()

@classmethod
def tearDownClass(cls) -> None:
# Skip django.test.TestCase.tearDownClass, call its super instead
super(django.test.TestCase, cls).tearDownClass()

return PytestDjangoTestCase


@pytest.fixture
def _django_db_helper(
def _sync_django_db_helper(
request: pytest.FixtureRequest,
django_db_setup: None,
django_db_blocker: DjangoDbBlocker,
Expand Down Expand Up @@ -239,7 +280,7 @@ def _django_db_helper(
"django_db_serialized_rollback" in request.fixturenames
)

with django_db_blocker.unblock():
with django_db_blocker.unblock(sync_only=not transactional):
import django.db
import django.test

Expand All @@ -248,41 +289,14 @@ def _django_db_helper(
else:
test_case_class = django.test.TestCase

_reset_sequences = reset_sequences
_serialized_rollback = serialized_rollback
_databases = databases
_available_apps = available_apps

class PytestDjangoTestCase(test_case_class): # type: ignore[misc,valid-type]
reset_sequences = _reset_sequences
serialized_rollback = _serialized_rollback
if _databases is not None:
databases = _databases
if _available_apps is not None:
available_apps = _available_apps

# For non-transactional tests, skip executing `django.test.TestCase`'s
# `setUpClass`/`tearDownClass`, only execute the super class ones.
#
# `TestCase`'s class setup manages the `setUpTestData`/class-level
# transaction functionality. We don't use it; instead we (will) offer
# our own alternatives. So it only adds overhead, and does some things
# which conflict with our (planned) functionality, particularly, it
# closes all database connections in `tearDownClass` which inhibits
# wrapping tests in higher-scoped transactions.
#
# It's possible a new version of Django will add some unrelated
# functionality to these methods, in which case skipping them completely
# would not be desirable. Let's cross that bridge when we get there...
if not transactional:

@classmethod
def setUpClass(cls) -> None:
super(django.test.TestCase, cls).setUpClass()

@classmethod
def tearDownClass(cls) -> None:
super(django.test.TestCase, cls).tearDownClass()
PytestDjangoTestCase = _build_pytest_django_test_case(
test_case_class,
reset_sequences=reset_sequences,
serialized_rollback=serialized_rollback,
databases=databases,
available_apps=available_apps,
skip_django_testcase_class_setup=(not transactional),
)

PytestDjangoTestCase.setUpClass()

Expand All @@ -298,6 +312,112 @@ def tearDownClass(cls) -> None:
PytestDjangoTestCase.doClassCleanups()


try:
import pytest_asyncio
except ImportError:

async def _async_django_db_helper(
request: pytest.FixtureRequest,
django_db_blocker: DjangoDbBlocker,
) -> AsyncGenerator[None, None]:
raise RuntimeError(
"The `pytest_asyncio` plugin is required to use the `async_django_db` fixture."
)
yield # pragma: no cover
else:

@pytest_asyncio.fixture
async def _async_django_db_helper(
request: pytest.FixtureRequest,
django_db_blocker: DjangoDbBlocker,
) -> AsyncGenerator[None, None]:
# same as _sync_django_db_helper, except for running the transaction start and rollback wrapped in a
# `sync_to_async` call
transactional, reset_sequences, databases, serialized_rollback, available_apps = (
_get_django_db_settings(request)
)

with django_db_blocker.unblock(async_only=True):
import django.db
import django.test

test_case_class = django.test.TestCase

PytestDjangoTestCase = _build_pytest_django_test_case(
test_case_class,
reset_sequences=reset_sequences,
serialized_rollback=serialized_rollback,
databases=databases,
available_apps=available_apps,
skip_django_testcase_class_setup=True,
)

from asgiref.sync import sync_to_async

await sync_to_async(PytestDjangoTestCase.setUpClass)()

test_case = PytestDjangoTestCase(methodName="__init__")
await sync_to_async(test_case._pre_setup, thread_sensitive=True)()

yield

await sync_to_async(test_case._post_teardown, thread_sensitive=True)()

await sync_to_async(PytestDjangoTestCase.tearDownClass)()

await sync_to_async(PytestDjangoTestCase.doClassCleanups)()


def _get_django_db_settings(request: pytest.FixtureRequest) -> _DjangoDb:
django_marker = request.node.get_closest_marker("django_db")
if django_marker:
(
transactional,
reset_sequences,
databases,
serialized_rollback,
available_apps,
) = validate_django_db(django_marker)
else:
(
transactional,
reset_sequences,
databases,
serialized_rollback,
available_apps,
) = False, False, None, False, None

transactional = (
transactional
or reset_sequences
or ("transactional_db" in request.fixturenames or "live_server" in request.fixturenames)
)

reset_sequences = reset_sequences or ("django_db_reset_sequences" in request.fixturenames)
serialized_rollback = serialized_rollback or (
"django_db_serialized_rollback" in request.fixturenames
)
return transactional, reset_sequences, databases, serialized_rollback, available_apps


@pytest.fixture
def _django_db_helper(
request: pytest.FixtureRequest,
django_db_setup: None,
django_db_blocker: DjangoDbBlocker,
):
asyncio_marker = request.node.get_closest_marker("asyncio")
transactional, *_ = _get_django_db_settings(request)
if transactional or not asyncio_marker:
# add the original sync fixture
request.getfixturevalue("_sync_django_db_helper")
else:
# add the async fixture. Will run it inside the event loop, which will cause the sync to async calls to
# start a transaction on the thread safe executor for that loop. This allows us to roll back orm calls made
# in that async test context.
request.getfixturevalue("_async_django_db_helper")


def _django_db_signature(
transaction: bool = False,
reset_sequences: bool = False,
Expand Down
Loading