Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 167 additions & 39 deletions channels/testing/live.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,33 @@
import threading
from functools import partial

from daphne.testing import DaphneProcess
from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
from django.core.exceptions import ImproperlyConfigured
from django.db import connections
from django.db.backends.base.creation import TEST_DATABASE_PREFIX
from django.test.testcases import TransactionTestCase
from django.test.utils import modify_settings
from django.utils.functional import classproperty
from django.utils.version import PY311

from channels.routing import get_default_application

if not PY311:
# Backport of unittest.case._enter_context() from Python 3.11.
def _enter_context(cm, addcleanup):
# Look up the special methods on the type to match the with statement.
cls = type(cm)
try:
enter = cls.__enter__
exit = cls.__exit__
except AttributeError:
raise TypeError(
f"'{cls.__module__}.{cls.__qualname__}' object does not support the "
f"context manager protocol"
) from None
result = enter(cm)
addcleanup(exit, cm, None, None, None)
return result


def make_application(*, static_wrapper):
# Module-level function for pickle-ability
Expand All @@ -28,66 +46,176 @@ def set_database_connection():
settings.DATABASES["default"]["NAME"] = test_db_name


class ChannelsLiveServerThread(threading.Thread):
"""Thread for running a live ASGI server while the tests are running."""

def __init__(
self, host, get_application, connections_override=None, port=0, setup=None
):
self.host = host
self.port = port
self.get_application = get_application
self.connections_override = connections_override
self.setup = setup
self.is_ready = threading.Event()
self.error = None
super().__init__()

def run(self):
"""
Set up the live server and databases, and then loop over handling
ASGI requests.
"""
if self.connections_override:
# Override this thread's database connections with the ones
# provided by the main thread.
for alias, conn in self.connections_override.items():
connections[alias] = conn

try:
# Reinstall the reactor for this thread (same as DaphneProcess)
from daphne.testing import _reinstall_reactor

_reinstall_reactor()

from daphne.endpoints import build_endpoint_description_strings
from daphne.server import Server

# Get the application
application = self.get_application()

# Create the server
endpoints = build_endpoint_description_strings(
host=self.host, port=self.port
)
self.server = Server(
application=application,
endpoints=endpoints,
signal_handlers=False,
ready_callable=self._set_ready,
verbosity=0,
)

# Run setup if provided
if self.setup is not None:
self.setup()

# Start the server
self.server.run()
except Exception as e:
self.error = e
self.is_ready.set()
finally:
connections.close_all()

def _set_ready(self):
"""Called by Daphne when the server is ready."""
if self.server.listening_addresses:
self.port = self.server.listening_addresses[0][1]
self.is_ready.set()

def terminate(self):
if hasattr(self, "server"):
# Stop the ASGI server
from twisted.internet import reactor

if reactor.running:
reactor.callFromThread(reactor.stop)
self.join(timeout=5)


class ChannelsLiveServerTestCase(TransactionTestCase):
"""
Does basically the same as TransactionTestCase but also launches a
live Daphne server in a separate process, so
that the tests may use another test framework, such as Selenium,
instead of the built-in dummy client.
Do basically the same as TransactionTestCase but also launch a live ASGI
server in a separate thread so that the tests may use another testing
framework, such as Selenium for example, instead of the built-in dummy
client.
It inherits from TransactionTestCase instead of TestCase because the
threads don't share the same transactions (unless if using in-memory
sqlite) and each thread needs to commit all their transactions so that the
other thread can see the changes.
"""

host = "localhost"
ProtocolServerProcess = DaphneProcess
static_wrapper = ASGIStaticFilesHandler
port = 0
server_thread_class = ChannelsLiveServerThread
static_handler = ASGIStaticFilesHandler
serve_static = True

@property
def live_server_url(self):
return "http://%s:%s" % (self.host, self._port)
if not PY311:
# Backport of unittest.TestCase.enterClassContext() from Python 3.11.
@classmethod
def enterClassContext(cls, cm):
return _enter_context(cm, cls.addClassCleanup)

@classproperty
def live_server_url(cls):
return "http://%s:%s" % (cls.host, cls.server_thread.port)

@property
def live_server_ws_url(self):
return "ws://%s:%s" % (self.host, self._port)
@classproperty
def live_server_ws_url(cls):
return "ws://%s:%s" % (cls.host, cls.server_thread.port)

@classproperty
def allowed_host(cls):
return cls.host

@classmethod
def setUpClass(cls):
for connection in connections.all():
if cls._is_in_memory_db(connection):
raise ImproperlyConfigured(
"ChannelLiveServerTestCase can not be used with in memory databases"
)
def _make_connections_override(cls):
connections_override = {}
for conn in connections.all():
# If using in-memory sqlite databases, pass the connections to
# the server thread.
if conn.vendor == "sqlite" and conn.is_in_memory_db():
connections_override[conn.alias] = conn
return connections_override

@classmethod
def setUpClass(cls):
super().setUpClass()

cls._live_server_modified_settings = modify_settings(
ALLOWED_HOSTS={"append": cls.host}
cls.enterClassContext(
modify_settings(ALLOWED_HOSTS={"append": cls.allowed_host})
)
cls._live_server_modified_settings.enable()
cls._start_server_thread()

@classmethod
def _start_server_thread(cls):
connections_override = cls._make_connections_override()
for conn in connections_override.values():
# Explicitly enable thread-shareability for this connection.
conn.inc_thread_sharing()

cls.server_thread = cls._create_server_thread(connections_override)
cls.server_thread.daemon = True
cls.server_thread.start()
cls.addClassCleanup(cls._terminate_thread)

# Wait for the live server to be ready
cls.server_thread.is_ready.wait()
if cls.server_thread.error:
raise cls.server_thread.error

@classmethod
def _create_server_thread(cls, connections_override):
get_application = partial(
make_application,
static_wrapper=cls.static_wrapper if cls.serve_static else None,
static_wrapper=cls.static_handler if cls.serve_static else None,
)
cls._server_process = cls.ProtocolServerProcess(
return cls.server_thread_class(
cls.host,
get_application,
connections_override=connections_override,
port=cls.port,
setup=set_database_connection,
)
cls._server_process.start()
while True:
if not cls._server_process.ready.wait(timeout=1):
if cls._server_process.is_alive():
continue
raise RuntimeError("Server stopped") from None
break
cls._port = cls._server_process.port.value

@classmethod
def tearDownClass(cls):
cls._server_process.terminate()
cls._server_process.join()
cls._live_server_modified_settings.disable()
super().tearDownClass()
def _terminate_thread(cls):
# Terminate the live server's thread.
cls.server_thread.terminate()
# Restore shared connections' non-shareability.
for conn in cls.server_thread.connections_override.values():
conn.dec_thread_sharing()

@classmethod
def _is_in_memory_db(cls, connection):
Expand Down