diff --git a/channels/testing/live.py b/channels/testing/live.py index aa1a7880..f343af5b 100644 --- a/channels/testing/live.py +++ b/channels/testing/live.py @@ -1,15 +1,33 @@ +import threading from functools import partial -from daphne.testing import DaphneProcess from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler -from django.core.exceptions import ImproperlyConfigured from django.db import connections from django.db.backends.base.creation import TEST_DATABASE_PREFIX from django.test.testcases import TransactionTestCase from django.test.utils import modify_settings +from django.utils.functional import classproperty +from django.utils.version import PY311 from channels.routing import get_default_application +if not PY311: + # Backport of unittest.case._enter_context() from Python 3.11. + def _enter_context(cm, addcleanup): + # Look up the special methods on the type to match the with statement. + cls = type(cm) + try: + enter = cls.__enter__ + exit = cls.__exit__ + except AttributeError: + raise TypeError( + f"'{cls.__module__}.{cls.__qualname__}' object does not support the " + f"context manager protocol" + ) from None + result = enter(cm) + addcleanup(exit, cm, None, None, None) + return result + def make_application(*, static_wrapper): # Module-level function for pickle-ability @@ -28,66 +46,176 @@ def set_database_connection(): settings.DATABASES["default"]["NAME"] = test_db_name +class ChannelsLiveServerThread(threading.Thread): + """Thread for running a live ASGI server while the tests are running.""" + + def __init__( + self, host, get_application, connections_override=None, port=0, setup=None + ): + self.host = host + self.port = port + self.get_application = get_application + self.connections_override = connections_override + self.setup = setup + self.is_ready = threading.Event() + self.error = None + super().__init__() + + def run(self): + """ + Set up the live server and databases, and then loop over handling + ASGI requests. + """ + if self.connections_override: + # Override this thread's database connections with the ones + # provided by the main thread. + for alias, conn in self.connections_override.items(): + connections[alias] = conn + + try: + # Reinstall the reactor for this thread (same as DaphneProcess) + from daphne.testing import _reinstall_reactor + + _reinstall_reactor() + + from daphne.endpoints import build_endpoint_description_strings + from daphne.server import Server + + # Get the application + application = self.get_application() + + # Create the server + endpoints = build_endpoint_description_strings( + host=self.host, port=self.port + ) + self.server = Server( + application=application, + endpoints=endpoints, + signal_handlers=False, + ready_callable=self._set_ready, + verbosity=0, + ) + + # Run setup if provided + if self.setup is not None: + self.setup() + + # Start the server + self.server.run() + except Exception as e: + self.error = e + self.is_ready.set() + finally: + connections.close_all() + + def _set_ready(self): + """Called by Daphne when the server is ready.""" + if self.server.listening_addresses: + self.port = self.server.listening_addresses[0][1] + self.is_ready.set() + + def terminate(self): + if hasattr(self, "server"): + # Stop the ASGI server + from twisted.internet import reactor + + if reactor.running: + reactor.callFromThread(reactor.stop) + self.join(timeout=5) + + class ChannelsLiveServerTestCase(TransactionTestCase): """ - Does basically the same as TransactionTestCase but also launches a - live Daphne server in a separate process, so - that the tests may use another test framework, such as Selenium, - instead of the built-in dummy client. + Do basically the same as TransactionTestCase but also launch a live ASGI + server in a separate thread so that the tests may use another testing + framework, such as Selenium for example, instead of the built-in dummy + client. + It inherits from TransactionTestCase instead of TestCase because the + threads don't share the same transactions (unless if using in-memory + sqlite) and each thread needs to commit all their transactions so that the + other thread can see the changes. """ host = "localhost" - ProtocolServerProcess = DaphneProcess - static_wrapper = ASGIStaticFilesHandler + port = 0 + server_thread_class = ChannelsLiveServerThread + static_handler = ASGIStaticFilesHandler serve_static = True - @property - def live_server_url(self): - return "http://%s:%s" % (self.host, self._port) + if not PY311: + # Backport of unittest.TestCase.enterClassContext() from Python 3.11. + @classmethod + def enterClassContext(cls, cm): + return _enter_context(cm, cls.addClassCleanup) + + @classproperty + def live_server_url(cls): + return "http://%s:%s" % (cls.host, cls.server_thread.port) - @property - def live_server_ws_url(self): - return "ws://%s:%s" % (self.host, self._port) + @classproperty + def live_server_ws_url(cls): + return "ws://%s:%s" % (cls.host, cls.server_thread.port) + + @classproperty + def allowed_host(cls): + return cls.host @classmethod - def setUpClass(cls): - for connection in connections.all(): - if cls._is_in_memory_db(connection): - raise ImproperlyConfigured( - "ChannelLiveServerTestCase can not be used with in memory databases" - ) + def _make_connections_override(cls): + connections_override = {} + for conn in connections.all(): + # If using in-memory sqlite databases, pass the connections to + # the server thread. + if conn.vendor == "sqlite" and conn.is_in_memory_db(): + connections_override[conn.alias] = conn + return connections_override + @classmethod + def setUpClass(cls): super().setUpClass() - - cls._live_server_modified_settings = modify_settings( - ALLOWED_HOSTS={"append": cls.host} + cls.enterClassContext( + modify_settings(ALLOWED_HOSTS={"append": cls.allowed_host}) ) - cls._live_server_modified_settings.enable() + cls._start_server_thread() + + @classmethod + def _start_server_thread(cls): + connections_override = cls._make_connections_override() + for conn in connections_override.values(): + # Explicitly enable thread-shareability for this connection. + conn.inc_thread_sharing() + + cls.server_thread = cls._create_server_thread(connections_override) + cls.server_thread.daemon = True + cls.server_thread.start() + cls.addClassCleanup(cls._terminate_thread) + + # Wait for the live server to be ready + cls.server_thread.is_ready.wait() + if cls.server_thread.error: + raise cls.server_thread.error + @classmethod + def _create_server_thread(cls, connections_override): get_application = partial( make_application, - static_wrapper=cls.static_wrapper if cls.serve_static else None, + static_wrapper=cls.static_handler if cls.serve_static else None, ) - cls._server_process = cls.ProtocolServerProcess( + return cls.server_thread_class( cls.host, get_application, + connections_override=connections_override, + port=cls.port, setup=set_database_connection, ) - cls._server_process.start() - while True: - if not cls._server_process.ready.wait(timeout=1): - if cls._server_process.is_alive(): - continue - raise RuntimeError("Server stopped") from None - break - cls._port = cls._server_process.port.value @classmethod - def tearDownClass(cls): - cls._server_process.terminate() - cls._server_process.join() - cls._live_server_modified_settings.disable() - super().tearDownClass() + def _terminate_thread(cls): + # Terminate the live server's thread. + cls.server_thread.terminate() + # Restore shared connections' non-shareability. + for conn in cls.server_thread.connections_override.values(): + conn.dec_thread_sharing() @classmethod def _is_in_memory_db(cls, connection):