Skip to content

Commit c6da185

Browse files
committed
wip
1 parent a135228 commit c6da185

File tree

1 file changed

+144
-37
lines changed

1 file changed

+144
-37
lines changed

channels/testing/live.py

Lines changed: 144 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1+
import threading
12
from functools import partial
23

3-
from daphne.testing import DaphneProcess
44
from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
5-
from django.core.exceptions import ImproperlyConfigured
65
from django.db import connections
76
from django.db.backends.base.creation import TEST_DATABASE_PREFIX
87
from django.test.testcases import TransactionTestCase
98
from django.test.utils import modify_settings
9+
from django.utils.functional import classproperty
1010

1111
from channels.routing import get_default_application
1212

@@ -28,65 +28,172 @@ def set_database_connection():
2828
settings.DATABASES["default"]["NAME"] = test_db_name
2929

3030

31+
class ChannelsLiveServerThread(threading.Thread):
32+
"""Thread for running a live ASGI server while the tests are running."""
33+
34+
def __init__(self, host, get_application, connections_override=None, port=0, setup=None):
35+
self.host = host
36+
self.port = port
37+
self.get_application = get_application
38+
self.connections_override = connections_override
39+
self.setup = setup
40+
self.is_ready = threading.Event()
41+
self.error = None
42+
super().__init__()
43+
44+
def run(self):
45+
"""
46+
Set up the live server and databases, and then loop over handling
47+
ASGI requests.
48+
"""
49+
if self.connections_override:
50+
# Override this thread's database connections with the ones
51+
# provided by the main thread.
52+
for alias, conn in self.connections_override.items():
53+
connections[alias] = conn
54+
55+
try:
56+
# Reinstall the reactor for this thread (same as DaphneProcess)
57+
from daphne.testing import _reinstall_reactor
58+
_reinstall_reactor()
59+
60+
from twisted.internet import reactor
61+
from daphne.endpoints import build_endpoint_description_strings
62+
from daphne.server import Server
63+
64+
# Get the application
65+
application = self.get_application()
66+
67+
# Create the server
68+
endpoints = build_endpoint_description_strings(
69+
host=self.host, port=self.port
70+
)
71+
self.server = Server(
72+
application=application,
73+
endpoints=endpoints,
74+
signal_handlers=False,
75+
ready_callable=self._set_ready,
76+
verbosity=0,
77+
)
78+
79+
# Run setup if provided
80+
if self.setup is not None:
81+
self.setup()
82+
83+
# Start the server
84+
self.server.run()
85+
except Exception as e:
86+
self.error = e
87+
self.is_ready.set()
88+
finally:
89+
connections.close_all()
90+
91+
def _set_ready(self):
92+
"""Called by Daphne when the server is ready."""
93+
if self.server.listening_addresses:
94+
self.port = self.server.listening_addresses[0][1]
95+
self.is_ready.set()
96+
97+
def terminate(self):
98+
if hasattr(self, "server"):
99+
# Stop the ASGI server
100+
from twisted.internet import reactor
101+
102+
if reactor.running:
103+
reactor.callFromThread(reactor.stop)
104+
self.join(timeout=5)
105+
106+
31107
class ChannelsLiveServerTestCase(TransactionTestCase):
32108
"""
33-
Does basically the same as TransactionTestCase but also launches a
34-
live Daphne server in a separate process, so
35-
that the tests may use another test framework, such as Selenium,
36-
instead of the built-in dummy client.
109+
Do basically the same as TransactionTestCase but also launch a live ASGI
110+
server in a separate thread so that the tests may use another testing
111+
framework, such as Selenium for example, instead of the built-in dummy
112+
client.
113+
It inherits from TransactionTestCase instead of TestCase because the
114+
threads don't share the same transactions (unless if using in-memory
115+
sqlite) and each thread needs to commit all their transactions so that the
116+
other thread can see the changes.
37117
"""
38118

39119
host = "localhost"
40-
ProtocolServerProcess = DaphneProcess
41-
static_wrapper = ASGIStaticFilesHandler
120+
port = 0
121+
server_thread_class = ChannelsLiveServerThread
122+
static_handler = ASGIStaticFilesHandler
42123
serve_static = True
43124

44-
@property
45-
def live_server_url(self):
46-
return "http://%s:%s" % (self.host, self._port)
125+
@classproperty
126+
def live_server_url(cls):
127+
return "http://%s:%s" % (cls.host, cls.server_thread.port)
47128

48-
@property
49-
def live_server_ws_url(self):
50-
return "ws://%s:%s" % (self.host, self._port)
129+
@classproperty
130+
def live_server_ws_url(cls):
131+
return "ws://%s:%s" % (cls.host, cls.server_thread.port)
132+
133+
@classproperty
134+
def allowed_host(cls):
135+
return cls.host
51136

52137
@classmethod
53-
def setUpClass(cls):
54-
for connection in connections.all():
55-
if cls._is_in_memory_db(connection):
56-
raise ImproperlyConfigured(
57-
"ChannelLiveServerTestCase can not be used with in memory databases"
58-
)
138+
def _make_connections_override(cls):
139+
connections_override = {}
140+
for conn in connections.all():
141+
# If using in-memory sqlite databases, pass the connections to
142+
# the server thread.
143+
if conn.vendor == "sqlite" and conn.is_in_memory_db():
144+
connections_override[conn.alias] = conn
145+
return connections_override
59146

147+
@classmethod
148+
def setUpClass(cls):
60149
super().setUpClass()
61-
62-
cls._live_server_modified_settings = modify_settings(
63-
ALLOWED_HOSTS={"append": cls.host}
150+
cls.enterClassContext(
151+
modify_settings(ALLOWED_HOSTS={"append": cls.allowed_host})
64152
)
65-
cls._live_server_modified_settings.enable()
153+
cls._start_server_thread()
154+
155+
@classmethod
156+
def _start_server_thread(cls):
157+
connections_override = cls._make_connections_override()
158+
for conn in connections_override.values():
159+
# Explicitly enable thread-shareability for this connection.
160+
conn.inc_thread_sharing()
161+
162+
cls.server_thread = cls._create_server_thread(connections_override)
163+
cls.server_thread.daemon = True
164+
cls.server_thread.start()
165+
cls.addClassCleanup(cls._terminate_thread)
166+
167+
# Wait for the live server to be ready
168+
cls.server_thread.is_ready.wait()
169+
if cls.server_thread.error:
170+
raise cls.server_thread.error
66171

172+
@classmethod
173+
def _create_server_thread(cls, connections_override):
67174
get_application = partial(
68175
make_application,
69-
static_wrapper=cls.static_wrapper if cls.serve_static else None,
176+
static_wrapper=cls.static_handler if cls.serve_static else None,
70177
)
71-
cls._server_process = cls.ProtocolServerProcess(
178+
return cls.server_thread_class(
72179
cls.host,
73180
get_application,
181+
connections_override=connections_override,
182+
port=cls.port,
74183
setup=set_database_connection,
75184
)
76-
cls._server_process.start()
77-
while True:
78-
if not cls._server_process.ready.wait(timeout=1):
79-
if cls._server_process.is_alive():
80-
continue
81-
raise RuntimeError("Server stopped") from None
82-
break
83-
cls._port = cls._server_process.port.value
185+
186+
@classmethod
187+
def _terminate_thread(cls):
188+
# Terminate the live server's thread.
189+
cls.server_thread.terminate()
190+
# Restore shared connections' non-shareability.
191+
for conn in cls.server_thread.connections_override.values():
192+
conn.dec_thread_sharing()
84193

85194
@classmethod
86195
def tearDownClass(cls):
87-
cls._server_process.terminate()
88-
cls._server_process.join()
89-
cls._live_server_modified_settings.disable()
196+
# The cleanup is now handled by addClassCleanup in _start_server_thread
90197
super().tearDownClass()
91198

92199
@classmethod

0 commit comments

Comments
 (0)