diff --git a/.travis.yml b/.travis.yml index 6e78e0be5..d9921f749 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,8 @@ dist: xenial +services: + - postgresql + language: python python: @@ -12,7 +15,7 @@ env: - DJANGO="Django==3.0.*" install: - - pip install -U pip wheel setuptools + - pip install -U pip wheel setuptools psycopg2 - pip install $DJANGO -e .[tests] - pip freeze diff --git a/channels/db.py b/channels/db.py index 13383fd11..0f13b07bd 100644 --- a/channels/db.py +++ b/channels/db.py @@ -1,19 +1,59 @@ -from django.db import close_old_connections +import django +from django.db import connections from asgiref.sync import SyncToAsync +HAS_INC_THREAD_SHARING = django.VERSION >= (2, 2) + + +def _close_old_connections(): + """Like django.db.close_old_connections, but skipping in_atomic_block. + + Ref: https://code.djangoproject.com/ticket/30448 + Ref: https://github.com/django/django/pull/11769 + """ + for conn in connections.all(): + if not conn.in_atomic_block: + conn.close_if_unusable_or_obsolete() + class DatabaseSyncToAsync(SyncToAsync): """ - SyncToAsync version that cleans up old database connections when it exits. + SyncToAsync version that cleans up old database connections. """ + def __init__(self, *args, **kwargs): + self.main_thread_connections = {name: connections[name] for name in connections} + super().__init__(*args, **kwargs) + + def _inherit_main_thread_connections(self): + """Copy/use DB connections in atomic block from main thread. + + This is required for tests using Django's TestCase. + """ + restore_allow_thread_sharing = {} + + for name in self.main_thread_connections: + if self.main_thread_connections[name].in_atomic_block: + connections[name] = self.main_thread_connections[name] + if HAS_INC_THREAD_SHARING: + connections[name].inc_thread_sharing() + else: + saved_sharing = connections[name].allow_thread_sharing + if not saved_sharing: + restore_allow_thread_sharing[name] = saved_sharing + connections[name].allow_thread_sharing = True + return restore_allow_thread_sharing + def thread_handler(self, loop, *args, **kwargs): - close_old_connections() + restore_allow_thread_sharing = self._inherit_main_thread_connections() + _close_old_connections() try: return super().thread_handler(loop, *args, **kwargs) finally: - close_old_connections() + _close_old_connections() + for name, saved_sharing in restore_allow_thread_sharing.items(): + connections[name].allow_thread_sharing = saved_sharing # The class is TitleCased, but we want to encourage use as a callable/decorator diff --git a/channels/signals.py b/channels/signals.py index 4899318e5..c81363673 100644 --- a/channels/signals.py +++ b/channels/signals.py @@ -1,8 +1,9 @@ -from django.db import close_old_connections from django.dispatch import Signal +from channels.db import _close_old_connections + consumer_started = Signal(providing_args=["environ"]) consumer_finished = Signal() # Connect connection closer to consumer finished as well -consumer_finished.connect(close_old_connections) +consumer_finished.connect(_close_old_connections) diff --git a/setup.cfg b/setup.cfg index ed37d52ea..83dfda0eb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,3 +15,4 @@ universal=1 [tool:pytest] testpaths = tests +addopts = -ra diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 000000000..f996aefe8 --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,71 @@ +import pytest + +pytest_plugins = ["pytester"] + + +@pytest.mark.django_db +@pytest.mark.asyncio +@pytest.mark.parametrize( + "db_engine", ("django.db.backends.sqlite3", "django.db.backends.postgresql") +) +@pytest.mark.parametrize("conn_max_age", (0, 600)) +async def test_database_sync_to_async(db_engine, conn_max_age, testdir): + if db_engine == "django.db.backends.postgresql": + pytest.importorskip("psycopg2") + + testdir.makeconftest( + """ + from django.conf import settings + + settings.configure( + DATABASES={ + "default": { + "ENGINE": %r, + "CONN_MAX_AGE": %d, + "NAME": "channels_tests", + } + }, + INSTALLED_APPS=[ + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.admin", + "channels", + ], + ) + """ + % (db_engine, conn_max_age) + ) + p1 = testdir.makepyfile( + """ + import pytest + from django.contrib.auth.models import User + + from channels.db import database_sync_to_async + + @database_sync_to_async + def create_obj(**kwargs): + User.objects.create(**kwargs) + + @pytest.mark.asyncio + @pytest.mark.django_db + async def test_inner(): + from django.db import connections + + conn = connections["default"] + assert conn.in_atomic_block + + assert User.objects.count() == 0 + + await create_obj(username="alice") + await create_obj(username="bob") + assert User.objects.count() == 2 + + @pytest.mark.django_db + def test_check_rolled_back(): + from django.contrib.auth.models import User + assert User.objects.count() == 0 + """ + ) + result = testdir.runpytest_subprocess(str(p1)) + assert result.ret == 0