Skip to content

Commit 446be96

Browse files
committed
wip
1 parent a135228 commit 446be96

File tree

1 file changed

+146
-37
lines changed

1 file changed

+146
-37
lines changed

channels/testing/live.py

Lines changed: 146 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1+
import threading
12
from functools import partial
23

3-
from daphne.testing import DaphneProcess
44
from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
5-
from django.core.exceptions import ImproperlyConfigured
65
from django.db import connections
76
from django.db.backends.base.creation import TEST_DATABASE_PREFIX
87
from django.test.testcases import TransactionTestCase
98
from django.test.utils import modify_settings
9+
from django.utils.functional import classproperty
1010

1111
from channels.routing import get_default_application
1212

@@ -28,65 +28,174 @@ def set_database_connection():
2828
settings.DATABASES["default"]["NAME"] = test_db_name
2929

3030

31+
class ChannelsLiveServerThread(threading.Thread):
32+
"""Thread for running a live ASGI server while the tests are running."""
33+
34+
def __init__(
35+
self, host, get_application, connections_override=None, port=0, setup=None
36+
):
37+
self.host = host
38+
self.port = port
39+
self.get_application = get_application
40+
self.connections_override = connections_override
41+
self.setup = setup
42+
self.is_ready = threading.Event()
43+
self.error = None
44+
super().__init__()
45+
46+
def run(self):
47+
"""
48+
Set up the live server and databases, and then loop over handling
49+
ASGI requests.
50+
"""
51+
if self.connections_override:
52+
# Override this thread's database connections with the ones
53+
# provided by the main thread.
54+
for alias, conn in self.connections_override.items():
55+
connections[alias] = conn
56+
57+
try:
58+
# Reinstall the reactor for this thread (same as DaphneProcess)
59+
from daphne.testing import _reinstall_reactor
60+
61+
_reinstall_reactor()
62+
63+
from daphne.endpoints import build_endpoint_description_strings
64+
from daphne.server import Server
65+
66+
# Get the application
67+
application = self.get_application()
68+
69+
# Create the server
70+
endpoints = build_endpoint_description_strings(
71+
host=self.host, port=self.port
72+
)
73+
self.server = Server(
74+
application=application,
75+
endpoints=endpoints,
76+
signal_handlers=False,
77+
ready_callable=self._set_ready,
78+
verbosity=0,
79+
)
80+
81+
# Run setup if provided
82+
if self.setup is not None:
83+
self.setup()
84+
85+
# Start the server
86+
self.server.run()
87+
except Exception as e:
88+
self.error = e
89+
self.is_ready.set()
90+
finally:
91+
connections.close_all()
92+
93+
def _set_ready(self):
94+
"""Called by Daphne when the server is ready."""
95+
if self.server.listening_addresses:
96+
self.port = self.server.listening_addresses[0][1]
97+
self.is_ready.set()
98+
99+
def terminate(self):
100+
if hasattr(self, "server"):
101+
# Stop the ASGI server
102+
from twisted.internet import reactor
103+
104+
if reactor.running:
105+
reactor.callFromThread(reactor.stop)
106+
self.join(timeout=5)
107+
108+
31109
class ChannelsLiveServerTestCase(TransactionTestCase):
32110
"""
33-
Does basically the same as TransactionTestCase but also launches a
34-
live Daphne server in a separate process, so
35-
that the tests may use another test framework, such as Selenium,
36-
instead of the built-in dummy client.
111+
Do basically the same as TransactionTestCase but also launch a live ASGI
112+
server in a separate thread so that the tests may use another testing
113+
framework, such as Selenium for example, instead of the built-in dummy
114+
client.
115+
It inherits from TransactionTestCase instead of TestCase because the
116+
threads don't share the same transactions (unless if using in-memory
117+
sqlite) and each thread needs to commit all their transactions so that the
118+
other thread can see the changes.
37119
"""
38120

39121
host = "localhost"
40-
ProtocolServerProcess = DaphneProcess
41-
static_wrapper = ASGIStaticFilesHandler
122+
port = 0
123+
server_thread_class = ChannelsLiveServerThread
124+
static_handler = ASGIStaticFilesHandler
42125
serve_static = True
43126

44-
@property
45-
def live_server_url(self):
46-
return "http://%s:%s" % (self.host, self._port)
127+
@classproperty
128+
def live_server_url(cls):
129+
return "http://%s:%s" % (cls.host, cls.server_thread.port)
47130

48-
@property
49-
def live_server_ws_url(self):
50-
return "ws://%s:%s" % (self.host, self._port)
131+
@classproperty
132+
def live_server_ws_url(cls):
133+
return "ws://%s:%s" % (cls.host, cls.server_thread.port)
134+
135+
@classproperty
136+
def allowed_host(cls):
137+
return cls.host
51138

52139
@classmethod
53-
def setUpClass(cls):
54-
for connection in connections.all():
55-
if cls._is_in_memory_db(connection):
56-
raise ImproperlyConfigured(
57-
"ChannelLiveServerTestCase can not be used with in memory databases"
58-
)
140+
def _make_connections_override(cls):
141+
connections_override = {}
142+
for conn in connections.all():
143+
# If using in-memory sqlite databases, pass the connections to
144+
# the server thread.
145+
if conn.vendor == "sqlite" and conn.is_in_memory_db():
146+
connections_override[conn.alias] = conn
147+
return connections_override
59148

149+
@classmethod
150+
def setUpClass(cls):
60151
super().setUpClass()
61-
62-
cls._live_server_modified_settings = modify_settings(
63-
ALLOWED_HOSTS={"append": cls.host}
152+
cls.enterClassContext(
153+
modify_settings(ALLOWED_HOSTS={"append": cls.allowed_host})
64154
)
65-
cls._live_server_modified_settings.enable()
155+
cls._start_server_thread()
156+
157+
@classmethod
158+
def _start_server_thread(cls):
159+
connections_override = cls._make_connections_override()
160+
for conn in connections_override.values():
161+
# Explicitly enable thread-shareability for this connection.
162+
conn.inc_thread_sharing()
163+
164+
cls.server_thread = cls._create_server_thread(connections_override)
165+
cls.server_thread.daemon = True
166+
cls.server_thread.start()
167+
cls.addClassCleanup(cls._terminate_thread)
168+
169+
# Wait for the live server to be ready
170+
cls.server_thread.is_ready.wait()
171+
if cls.server_thread.error:
172+
raise cls.server_thread.error
66173

174+
@classmethod
175+
def _create_server_thread(cls, connections_override):
67176
get_application = partial(
68177
make_application,
69-
static_wrapper=cls.static_wrapper if cls.serve_static else None,
178+
static_wrapper=cls.static_handler if cls.serve_static else None,
70179
)
71-
cls._server_process = cls.ProtocolServerProcess(
180+
return cls.server_thread_class(
72181
cls.host,
73182
get_application,
183+
connections_override=connections_override,
184+
port=cls.port,
74185
setup=set_database_connection,
75186
)
76-
cls._server_process.start()
77-
while True:
78-
if not cls._server_process.ready.wait(timeout=1):
79-
if cls._server_process.is_alive():
80-
continue
81-
raise RuntimeError("Server stopped") from None
82-
break
83-
cls._port = cls._server_process.port.value
187+
188+
@classmethod
189+
def _terminate_thread(cls):
190+
# Terminate the live server's thread.
191+
cls.server_thread.terminate()
192+
# Restore shared connections' non-shareability.
193+
for conn in cls.server_thread.connections_override.values():
194+
conn.dec_thread_sharing()
84195

85196
@classmethod
86197
def tearDownClass(cls):
87-
cls._server_process.terminate()
88-
cls._server_process.join()
89-
cls._live_server_modified_settings.disable()
198+
# The cleanup is now handled by addClassCleanup in _start_server_thread
90199
super().tearDownClass()
91200

92201
@classmethod

0 commit comments

Comments
 (0)