Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
dist: xenial

services:
- postgresql

language: python

python:
Expand All @@ -12,7 +15,7 @@ env:
- DJANGO="Django==3.0.*"

install:
- pip install -U pip wheel setuptools
- pip install -U pip wheel setuptools psycopg2
- pip install $DJANGO -e .[tests]
- pip freeze

Expand Down
48 changes: 44 additions & 4 deletions channels/db.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,59 @@
from django.db import close_old_connections
import django
from django.db import connections

from asgiref.sync import SyncToAsync

HAS_INC_THREAD_SHARING = django.VERSION >= (2, 2)


def _close_old_connections():
"""Like django.db.close_old_connections, but skipping in_atomic_block.

Ref: https://code.djangoproject.com/ticket/30448
Ref: https://github.com/django/django/pull/11769
"""
for conn in connections.all():
if not conn.in_atomic_block:
conn.close_if_unusable_or_obsolete()


class DatabaseSyncToAsync(SyncToAsync):
"""
SyncToAsync version that cleans up old database connections when it exits.
SyncToAsync version that cleans up old database connections.
"""

def __init__(self, *args, **kwargs):
self.main_thread_connections = {name: connections[name] for name in connections}
super().__init__(*args, **kwargs)

def _inherit_main_thread_connections(self):
"""Copy/use DB connections in atomic block from main thread.

This is required for tests using Django's TestCase.
"""
restore_allow_thread_sharing = {}

for name in self.main_thread_connections:
if self.main_thread_connections[name].in_atomic_block:
connections[name] = self.main_thread_connections[name]
if HAS_INC_THREAD_SHARING:
connections[name].inc_thread_sharing()
else:
saved_sharing = connections[name].allow_thread_sharing
if not saved_sharing:
restore_allow_thread_sharing[name] = saved_sharing
connections[name].allow_thread_sharing = True
return restore_allow_thread_sharing

def thread_handler(self, loop, *args, **kwargs):
close_old_connections()
restore_allow_thread_sharing = self._inherit_main_thread_connections()
_close_old_connections()
try:
return super().thread_handler(loop, *args, **kwargs)
finally:
close_old_connections()
_close_old_connections()
for name, saved_sharing in restore_allow_thread_sharing.items():
connections[name].allow_thread_sharing = saved_sharing


# The class is TitleCased, but we want to encourage use as a callable/decorator
Expand Down
5 changes: 3 additions & 2 deletions channels/signals.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from django.db import close_old_connections
from django.dispatch import Signal

from channels.db import _close_old_connections

consumer_started = Signal(providing_args=["environ"])
consumer_finished = Signal()

# Connect connection closer to consumer finished as well
consumer_finished.connect(close_old_connections)
consumer_finished.connect(_close_old_connections)
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ universal=1

[tool:pytest]
testpaths = tests
addopts = -ra
71 changes: 71 additions & 0 deletions tests/test_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pytest

pytest_plugins = ["pytester"]


@pytest.mark.django_db
@pytest.mark.asyncio
@pytest.mark.parametrize(
"db_engine", ("django.db.backends.sqlite3", "django.db.backends.postgresql")
)
@pytest.mark.parametrize("conn_max_age", (0, 600))
async def test_database_sync_to_async(db_engine, conn_max_age, testdir):
if db_engine == "django.db.backends.postgresql":
pytest.importorskip("psycopg2")

testdir.makeconftest(
"""
from django.conf import settings

settings.configure(
DATABASES={
"default": {
"ENGINE": %r,
"CONN_MAX_AGE": %d,
"NAME": "channels_tests",
}
},
INSTALLED_APPS=[
"django.contrib.auth",
"django.contrib.contenttypes",
"django.contrib.sessions",
"django.contrib.admin",
"channels",
],
)
"""
% (db_engine, conn_max_age)
)
p1 = testdir.makepyfile(
"""
import pytest
from django.contrib.auth.models import User

from channels.db import database_sync_to_async

@database_sync_to_async
def create_obj(**kwargs):
User.objects.create(**kwargs)

@pytest.mark.asyncio
@pytest.mark.django_db
async def test_inner():
from django.db import connections

conn = connections["default"]
assert conn.in_atomic_block

assert User.objects.count() == 0

await create_obj(username="alice")
await create_obj(username="bob")
assert User.objects.count() == 2

@pytest.mark.django_db
def test_check_rolled_back():
from django.contrib.auth.models import User
assert User.objects.count() == 0
"""
)
result = testdir.runpytest_subprocess(str(p1))
assert result.ret == 0