diff --git a/docs/database.rst b/docs/database.rst index fcdd219a..8f1511c8 100644 --- a/docs/database.rst +++ b/docs/database.rst @@ -60,6 +60,88 @@ select using an argument to the ``django_db`` mark:: def test_spam(): pass # test relying on transactions + +Async tests and database transactions +------------------------------------- + +``pytest-django`` supports async tests that use Django's async ORM APIs. +This requires the `pytest-asyncio `_ +plugin and marking your tests appropriately. + +Requirements +------------ + +- Install ``pytest-asyncio``. +- Mark async tests with both ``@pytest.mark.asyncio`` and + ``@pytest.mark.django_db`` (or request the ``db``/``transactional_db`` fixtures). + +Example (async ORM with transactional rollback per test):: + + import pytest + + @pytest.mark.asyncio + @pytest.mark.django_db + async def test_async_db_is_isolated(): + assert await Item.objects.acount() == 0 + await Item.objects.acreate(name="example") + assert await Item.objects.acount() == 1 + # changes are rolled back after the test + +.. _`async-db-behavior`: + +Behavior of ``db`` in async tests +--------------------------------- + +Tests using ``db`` wrap each test in a transaction and roll that transaction back at the end +(like ``django.test.TestCase``). In Django, transactions are bound to the database +connection, which is unique per thread. This means that all your database changes +must be made within the same thread to ensure they are rolled back before the next test. + +Django Async ORM calls, as of writing, use the ``asgiref.sync.sync_to_async`` +decorator to run the ORM calls on a dedicated thread executor. + +For async tests, pytest-django ensures the transaction +setup/teardown happens via ``asgiref.sync.sync_to_async``, which means the transaction is started & run on the +same thread on which async orm calls inside your test, like ``aget()`` are made. This ensures your test code +can safely modify the database using the async calls, as all its queries will be rolled back after the test. + +Tests using ``transactional_db`` flush the database between tests. This means that no matter in which thread +your test modifies the database, the changes will be removed after the test. This means you can avoid thinking +about sync/async database access if your test uses ``transactional_db``, at the cost of slower tests: +A flush is generally slower than rolling back a transaction. + +.. _`db-thread-safeguards`: + +Safeguards against database access from different threads +--------------------------------------------------------- +When using the database in a test with transaction rollback, you must ensure that +database access is only done from the same thread that the test is running on. + +To avoid your fixtures/tests making changes outside the test thread, and as a result, the transaction, pytest-django +actively restricts where database connections may be opened: + +- In async tests using ``db``: database access is only allowed from the single + thread used by ``SyncToAsync``. Using sync fixtures that touch the database in + an async test will raise:: + + RuntimeError: Database access is only allowed in an async context, modify your + test fixtures to be async or use the transactional_db fixture. + + Fix by converting those fixtures to async (use ``pytest_asyncio.fixture``) and + using Django's async ORM methods (e.g. ``.acreate()``, ``.aget()``, ``.acount()``), + or by requesting ``transactional_db`` if you must keep sync fixtures. + See :ref:`async-db-behavior` for more details. + +- In sync tests: database access is only allowed from the main thread. Attempting to use the database connection + from a different thread will raise:: + + RuntimeError: Database access is only allowed in the main thread, modify your + test fixtures to be sync or use the transactional_db fixture. + + Fix this by ensuring all database transactions run in the main thread (e.g., avoiding the use of async fixtures), + or use ``transactional_db`` to allow mixing. + + .. _`multi-db`: Tests requiring multiple databases @@ -524,3 +606,4 @@ Put this in ``conftest.py``:: django_db_blocker.unblock() yield django_db_blocker.restore() + diff --git a/pyproject.toml b/pyproject.toml index 883143af..16d204e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,10 @@ coverage = [ "coverage[toml]", "coverage-enable-subprocess", ] +async = [ + "asgiref>=3.9.1", + "pytest-asyncio", +] postgres = [ "psycopg[binary]", ] diff --git a/pytest_django/fixtures.py b/pytest_django/fixtures.py index 115dc4cc..b70d7f55 100644 --- a/pytest_django/fixtures.py +++ b/pytest_django/fixtures.py @@ -3,7 +3,7 @@ from __future__ import annotations import os -from collections.abc import Generator, Iterable, Sequence +from collections.abc import AsyncGenerator, Generator, Iterable, Sequence from contextlib import AbstractContextManager, contextmanager from functools import partial from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Protocol, Union @@ -201,8 +201,49 @@ def django_db_setup( ) +def _build_pytest_django_test_case( + test_case_class, + *, + reset_sequences: bool, + serialized_rollback: bool, + databases, + available_apps, + skip_django_testcase_class_setup: bool, +): + # Build a custom TestCase subclass with configured attributes and optional + # overrides to skip Django's TestCase class-level setup/teardown. + import django.test # local import to avoid hard dependency at import time + + _reset_sequences = reset_sequences + _serialized_rollback = serialized_rollback + _databases = databases + _available_apps = available_apps + + class PytestDjangoTestCase(test_case_class): + reset_sequences = _reset_sequences + serialized_rollback = _serialized_rollback + if _databases is not None: + databases = _databases + if _available_apps is not None: + available_apps = _available_apps + + if skip_django_testcase_class_setup: + + @classmethod + def setUpClass(cls) -> None: + # Skip django.test.TestCase.setUpClass, call its super instead + super(django.test.TestCase, cls).setUpClass() + + @classmethod + def tearDownClass(cls) -> None: + # Skip django.test.TestCase.tearDownClass, call its super instead + super(django.test.TestCase, cls).tearDownClass() + + return PytestDjangoTestCase + + @pytest.fixture -def _django_db_helper( +def _sync_django_db_helper( request: pytest.FixtureRequest, django_db_setup: None, django_db_blocker: DjangoDbBlocker, @@ -239,7 +280,7 @@ def _django_db_helper( "django_db_serialized_rollback" in request.fixturenames ) - with django_db_blocker.unblock(): + with django_db_blocker.unblock(sync_only=not transactional): import django.db import django.test @@ -248,41 +289,14 @@ def _django_db_helper( else: test_case_class = django.test.TestCase - _reset_sequences = reset_sequences - _serialized_rollback = serialized_rollback - _databases = databases - _available_apps = available_apps - - class PytestDjangoTestCase(test_case_class): # type: ignore[misc,valid-type] - reset_sequences = _reset_sequences - serialized_rollback = _serialized_rollback - if _databases is not None: - databases = _databases - if _available_apps is not None: - available_apps = _available_apps - - # For non-transactional tests, skip executing `django.test.TestCase`'s - # `setUpClass`/`tearDownClass`, only execute the super class ones. - # - # `TestCase`'s class setup manages the `setUpTestData`/class-level - # transaction functionality. We don't use it; instead we (will) offer - # our own alternatives. So it only adds overhead, and does some things - # which conflict with our (planned) functionality, particularly, it - # closes all database connections in `tearDownClass` which inhibits - # wrapping tests in higher-scoped transactions. - # - # It's possible a new version of Django will add some unrelated - # functionality to these methods, in which case skipping them completely - # would not be desirable. Let's cross that bridge when we get there... - if not transactional: - - @classmethod - def setUpClass(cls) -> None: - super(django.test.TestCase, cls).setUpClass() - - @classmethod - def tearDownClass(cls) -> None: - super(django.test.TestCase, cls).tearDownClass() + PytestDjangoTestCase = _build_pytest_django_test_case( + test_case_class, + reset_sequences=reset_sequences, + serialized_rollback=serialized_rollback, + databases=databases, + available_apps=available_apps, + skip_django_testcase_class_setup=(not transactional), + ) PytestDjangoTestCase.setUpClass() @@ -298,6 +312,112 @@ def tearDownClass(cls) -> None: PytestDjangoTestCase.doClassCleanups() +try: + import pytest_asyncio +except ImportError: + + async def _async_django_db_helper( + request: pytest.FixtureRequest, + django_db_blocker: DjangoDbBlocker, + ) -> AsyncGenerator[None, None]: + raise RuntimeError( + "The `pytest_asyncio` plugin is required to use the `async_django_db` fixture." + ) + yield # pragma: no cover +else: + + @pytest_asyncio.fixture + async def _async_django_db_helper( + request: pytest.FixtureRequest, + django_db_blocker: DjangoDbBlocker, + ) -> AsyncGenerator[None, None]: + # same as _sync_django_db_helper, except for running the transaction start and rollback wrapped in a + # `sync_to_async` call + transactional, reset_sequences, databases, serialized_rollback, available_apps = ( + _get_django_db_settings(request) + ) + + with django_db_blocker.unblock(async_only=True): + import django.db + import django.test + + test_case_class = django.test.TestCase + + PytestDjangoTestCase = _build_pytest_django_test_case( + test_case_class, + reset_sequences=reset_sequences, + serialized_rollback=serialized_rollback, + databases=databases, + available_apps=available_apps, + skip_django_testcase_class_setup=True, + ) + + from asgiref.sync import sync_to_async + + await sync_to_async(PytestDjangoTestCase.setUpClass)() + + test_case = PytestDjangoTestCase(methodName="__init__") + await sync_to_async(test_case._pre_setup, thread_sensitive=True)() + + yield + + await sync_to_async(test_case._post_teardown, thread_sensitive=True)() + + await sync_to_async(PytestDjangoTestCase.tearDownClass)() + + await sync_to_async(PytestDjangoTestCase.doClassCleanups)() + + +def _get_django_db_settings(request: pytest.FixtureRequest) -> _DjangoDb: + django_marker = request.node.get_closest_marker("django_db") + if django_marker: + ( + transactional, + reset_sequences, + databases, + serialized_rollback, + available_apps, + ) = validate_django_db(django_marker) + else: + ( + transactional, + reset_sequences, + databases, + serialized_rollback, + available_apps, + ) = False, False, None, False, None + + transactional = ( + transactional + or reset_sequences + or ("transactional_db" in request.fixturenames or "live_server" in request.fixturenames) + ) + + reset_sequences = reset_sequences or ("django_db_reset_sequences" in request.fixturenames) + serialized_rollback = serialized_rollback or ( + "django_db_serialized_rollback" in request.fixturenames + ) + return transactional, reset_sequences, databases, serialized_rollback, available_apps + + +@pytest.fixture +def _django_db_helper( + request: pytest.FixtureRequest, + django_db_setup: None, + django_db_blocker: DjangoDbBlocker, +): + asyncio_marker = request.node.get_closest_marker("asyncio") + transactional, *_ = _get_django_db_settings(request) + if transactional or not asyncio_marker: + # add the original sync fixture + request.getfixturevalue("_sync_django_db_helper") + else: + # add the async fixture. Will run it inside the event loop, which will cause the sync to async calls to + # start a transaction on the thread safe executor for that loop. This allows us to roll back orm calls made + # in that async test context. + request.getfixturevalue("_async_django_db_helper") + + def _django_db_signature( transaction: bool = False, reset_sequences: bool = False, diff --git a/pytest_django/plugin.py b/pytest_django/plugin.py index 9bab8971..89395ccd 100644 --- a/pytest_django/plugin.py +++ b/pytest_django/plugin.py @@ -11,18 +11,21 @@ import os import pathlib import sys +import threading import types from collections.abc import Generator from contextlib import AbstractContextManager from functools import reduce -from typing import TYPE_CHECKING, NoReturn +from typing import TYPE_CHECKING, Any, Callable, NoReturn import pytest from .django_compat import is_django_unittest from .fixtures import ( + _async_django_db_helper, # noqa: F401 _django_db_helper, # noqa: F401 _live_server_helper, # noqa: F401 + _sync_django_db_helper, # noqa: F401 admin_client, # noqa: F401 admin_user, # noqa: F401 async_client, # noqa: F401 @@ -815,7 +818,7 @@ def __init__(self, *, _ispytest: bool = False) -> None: ) self._history = [] # type: ignore[var-annotated] - self._real_ensure_connection = None + self._real_ensure_connection: None | Callable[[Any], Any] = None @property def _dj_db_wrapper(self) -> django.db.backends.base.base.BaseDatabaseWrapper: @@ -831,18 +834,60 @@ def _dj_db_wrapper(self) -> django.db.backends.base.base.BaseDatabaseWrapper: def _save_active_wrapper(self) -> None: self._history.append(self._dj_db_wrapper.ensure_connection) - def _blocking_wrapper(*args, **kwargs) -> NoReturn: + def _blocking_wrapper(self, *args, **kwargs) -> NoReturn: __tracebackhide__ = True raise RuntimeError( "Database access not allowed, " 'use the "django_db" mark, or the ' - '"db" or "transactional_db" fixtures to enable it.' + '"db" or "transactional_db" fixtures to enable it. ' ) - def unblock(self) -> AbstractContextManager[None]: + def _unblocked_async_only(self, wrapper_self: Any, *args, **kwargs): + __tracebackhide__ = True + from asgiref.sync import SyncToAsync + + is_in_sync_to_async_thread = ( + next(iter(SyncToAsync.single_thread_executor._threads)) == threading.current_thread() + ) + if not is_in_sync_to_async_thread: + raise RuntimeError( + "Database access is only allowed in an async context, " + "modify your test fixtures to be async or use the transactional_db fixture." + "See https://pytest-django.readthedocs.io/en/latest/database.html#db-thread-safeguards for more information." + ) + elif self._real_ensure_connection is not None: + self._real_ensure_connection(wrapper_self, *args, **kwargs) + + def _unblocked_sync_only(self, wrapper_self: Any, *args, **kwargs): + __tracebackhide__ = True + if threading.current_thread() != threading.main_thread(): + raise RuntimeError( + "Database access is only allowed in the main thread, " + "modify your test fixtures to be sync or use the transactional_db fixture." + "See https://pytest-django.readthedocs.io/en/latest/database.html#db-thread-safeguards for more information." + ) + elif self._real_ensure_connection is not None: + self._real_ensure_connection(wrapper_self, *args, **kwargs) + + def unblock(self, sync_only=False, async_only=False) -> AbstractContextManager[None]: """Enable access to the Django database.""" + if sync_only and async_only: + raise ValueError("Cannot use both sync_only and async_only. Choose at most one.") self._save_active_wrapper() - self._dj_db_wrapper.ensure_connection = self._real_ensure_connection + if sync_only: + + def _method(wrapper_self, *args, **kwargs): + return self._unblocked_sync_only(wrapper_self, *args, **kwargs) + + self._dj_db_wrapper.ensure_connection = _method + elif async_only: + + def _method(wrapper_self, *args, **kwargs): + return self._unblocked_async_only(wrapper_self, *args, **kwargs) + + self._dj_db_wrapper.ensure_connection = _method + else: + self._dj_db_wrapper.ensure_connection = self._real_ensure_connection return _DatabaseBlockerContextManager(self) def block(self) -> AbstractContextManager[None]: diff --git a/tests/test_async_db.py b/tests/test_async_db.py new file mode 100644 index 00000000..be36c1e4 --- /dev/null +++ b/tests/test_async_db.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from typing import Any, cast + +import pytest +from _pytest.mark import MarkDecorator + +from pytest_django_test.app.models import Item + + +try: + import pytest_asyncio +except ImportError: + pytestmark: MarkDecorator = pytest.mark.skip("pytest-asyncio is not installed") + fixturemark: MarkDecorator = pytest.mark.skip("pytest-asyncio is not installed") + +else: + pytestmark = pytest.mark.asyncio + fixturemark = cast(MarkDecorator, pytest_asyncio.fixture) + + +@pytest.mark.parametrize("run_number", [1, 2]) +@pytestmark +@pytest.mark.django_db +async def test_async_db(db, run_number) -> None: + # test async database usage remains isolated between tests + + assert await Item.objects.acount() == 0 + # make a new item instance, to be rolled back by the transaction wrapper before the next parametrized run + await Item.objects.acreate(name="blah") + assert await Item.objects.acount() == 1 + + +@fixturemark +async def db_item(db) -> Any: + return await Item.objects.acreate(name="async") + + +@pytest.fixture +def sync_db_item(db) -> Any: + return Item.objects.create(name="sync") + + +@pytestmark +@pytest.mark.xfail(strict=True, reason="Sync fixture used in async test") +async def test_db_item(db_item: Item, sync_db_item) -> None: + pass + + +@pytest.mark.xfail(strict=True, reason="Async fixture used in sync test") +def test_sync_db_item(async_db_item: Item, sync_db_item) -> None: + pass diff --git a/tests/test_db_thread_safeguards.py b/tests/test_db_thread_safeguards.py new file mode 100644 index 00000000..225e50dd --- /dev/null +++ b/tests/test_db_thread_safeguards.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import threading + +import pytest + +from pytest_django_test.app.models import Item + + +@pytest.mark.django_db +def test_sync_db_access_in_non_main_thread_is_blocked() -> None: + """ + Ensure that when using the sync django_db helper (non-transactional), + database access from a different thread raises the expected RuntimeError + stating that DB access is only allowed in the main thread. + + Mirrors the intent of the async equivalent test that checks thread + safeguards for async contexts. + """ + captured: list[BaseException | None] = [None] + + def worker() -> None: + try: + # Any ORM operation that touches the DB will attempt to ensure a connection. + # This should raise from the "sync_only" db blocker in non-main threads. + Item.objects.count() + except BaseException as exc: # noqa: BLE001 - we want to capture exactly what is raised + captured[0] = exc + + t = threading.Thread(target=worker) + t.start() + t.join() + + assert captured[0] is not None, "Expected DB access in worker thread to raise an exception" + assert isinstance(captured[0], RuntimeError) + assert "only allowed in the main thread" in str(captured[0]) diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py index 80578959..5227e434 100644 --- a/tests/test_fixtures.py +++ b/tests/test_fixtures.py @@ -752,6 +752,13 @@ def test_unblock_with_block(self, django_db_blocker: DjangoDbBlocker) -> None: with django_db_blocker.unblock(): Item.objects.exists() + def test_unblock_with_both_flags_raises_valueerror( + self, django_db_blocker: DjangoDbBlocker + ) -> None: + # When both sync_only and async_only are True, unblock should reject with ValueError + with pytest.raises(ValueError, match="Cannot use both sync_only and async_only"): + django_db_blocker.unblock(sync_only=True, async_only=True) + def test_mail(mailoutbox) -> None: assert mailoutbox is mail.outbox # check that mail.outbox and fixture value is same object diff --git a/tox.ini b/tox.ini index 1caaf78d..e173a57e 100644 --- a/tox.ini +++ b/tox.ini @@ -10,6 +10,7 @@ envlist = [testenv] dependency_groups = testing + !dj42: async coverage: coverage mysql: mysql postgres: postgres @@ -43,7 +44,9 @@ commands = coverage: coverage xml [testenv:linting] -dependency_groups = linting +dependency_groups = + linting + async commands = ruff check --diff {posargs:pytest_django pytest_django_test tests} ruff format --quiet --diff {posargs:pytest_django pytest_django_test tests}