1+ import threading
12from functools import partial
23
3- from daphne .testing import DaphneProcess
44from django .contrib .staticfiles .handlers import ASGIStaticFilesHandler
5- from django .core .exceptions import ImproperlyConfigured
65from django .db import connections
76from django .db .backends .base .creation import TEST_DATABASE_PREFIX
87from django .test .testcases import TransactionTestCase
98from django .test .utils import modify_settings
9+ from django .utils .functional import classproperty
1010
1111from channels .routing import get_default_application
1212
@@ -28,65 +28,174 @@ def set_database_connection():
2828 settings .DATABASES ["default" ]["NAME" ] = test_db_name
2929
3030
31+ class ChannelsLiveServerThread (threading .Thread ):
32+ """Thread for running a live ASGI server while the tests are running."""
33+
34+ def __init__ (
35+ self , host , get_application , connections_override = None , port = 0 , setup = None
36+ ):
37+ self .host = host
38+ self .port = port
39+ self .get_application = get_application
40+ self .connections_override = connections_override
41+ self .setup = setup
42+ self .is_ready = threading .Event ()
43+ self .error = None
44+ super ().__init__ ()
45+
46+ def run (self ):
47+ """
48+ Set up the live server and databases, and then loop over handling
49+ ASGI requests.
50+ """
51+ if self .connections_override :
52+ # Override this thread's database connections with the ones
53+ # provided by the main thread.
54+ for alias , conn in self .connections_override .items ():
55+ connections [alias ] = conn
56+
57+ try :
58+ # Reinstall the reactor for this thread (same as DaphneProcess)
59+ from daphne .testing import _reinstall_reactor
60+
61+ _reinstall_reactor ()
62+
63+ from daphne .endpoints import build_endpoint_description_strings
64+ from daphne .server import Server
65+
66+ # Get the application
67+ application = self .get_application ()
68+
69+ # Create the server
70+ endpoints = build_endpoint_description_strings (
71+ host = self .host , port = self .port
72+ )
73+ self .server = Server (
74+ application = application ,
75+ endpoints = endpoints ,
76+ signal_handlers = False ,
77+ ready_callable = self ._set_ready ,
78+ verbosity = 0 ,
79+ )
80+
81+ # Run setup if provided
82+ if self .setup is not None :
83+ self .setup ()
84+
85+ # Start the server
86+ self .server .run ()
87+ except Exception as e :
88+ self .error = e
89+ self .is_ready .set ()
90+ finally :
91+ connections .close_all ()
92+
93+ def _set_ready (self ):
94+ """Called by Daphne when the server is ready."""
95+ if self .server .listening_addresses :
96+ self .port = self .server .listening_addresses [0 ][1 ]
97+ self .is_ready .set ()
98+
99+ def terminate (self ):
100+ if hasattr (self , "server" ):
101+ # Stop the ASGI server
102+ from twisted .internet import reactor
103+
104+ if reactor .running :
105+ reactor .callFromThread (reactor .stop )
106+ self .join (timeout = 5 )
107+
108+
31109class ChannelsLiveServerTestCase (TransactionTestCase ):
32110 """
33- Does basically the same as TransactionTestCase but also launches a
34- live Daphne server in a separate process, so
35- that the tests may use another test framework, such as Selenium,
36- instead of the built-in dummy client.
111+ Do basically the same as TransactionTestCase but also launch a live ASGI
112+ server in a separate thread so that the tests may use another testing
113+ framework, such as Selenium for example, instead of the built-in dummy
114+ client.
115+ It inherits from TransactionTestCase instead of TestCase because the
116+ threads don't share the same transactions (unless if using in-memory
117+ sqlite) and each thread needs to commit all their transactions so that the
118+ other thread can see the changes.
37119 """
38120
39121 host = "localhost"
40- ProtocolServerProcess = DaphneProcess
41- static_wrapper = ASGIStaticFilesHandler
122+ port = 0
123+ server_thread_class = ChannelsLiveServerThread
124+ static_handler = ASGIStaticFilesHandler
42125 serve_static = True
43126
44- @property
45- def live_server_url (self ):
46- return "http://%s:%s" % (self .host , self . _port )
127+ @classproperty
128+ def live_server_url (cls ):
129+ return "http://%s:%s" % (cls .host , cls . server_thread . port )
47130
48- @property
49- def live_server_ws_url (self ):
50- return "ws://%s:%s" % (self .host , self ._port )
131+ @classproperty
132+ def live_server_ws_url (cls ):
133+ return "ws://%s:%s" % (cls .host , cls .server_thread .port )
134+
135+ @classproperty
136+ def allowed_host (cls ):
137+ return cls .host
51138
52139 @classmethod
53- def setUpClass (cls ):
54- for connection in connections .all ():
55- if cls ._is_in_memory_db (connection ):
56- raise ImproperlyConfigured (
57- "ChannelLiveServerTestCase can not be used with in memory databases"
58- )
140+ def _make_connections_override (cls ):
141+ connections_override = {}
142+ for conn in connections .all ():
143+ # If using in-memory sqlite databases, pass the connections to
144+ # the server thread.
145+ if conn .vendor == "sqlite" and conn .is_in_memory_db ():
146+ connections_override [conn .alias ] = conn
147+ return connections_override
59148
149+ @classmethod
150+ def setUpClass (cls ):
60151 super ().setUpClass ()
61-
62- cls ._live_server_modified_settings = modify_settings (
63- ALLOWED_HOSTS = {"append" : cls .host }
152+ cls .enterClassContext (
153+ modify_settings (ALLOWED_HOSTS = {"append" : cls .allowed_host })
64154 )
65- cls ._live_server_modified_settings .enable ()
155+ cls ._start_server_thread ()
156+
157+ @classmethod
158+ def _start_server_thread (cls ):
159+ connections_override = cls ._make_connections_override ()
160+ for conn in connections_override .values ():
161+ # Explicitly enable thread-shareability for this connection.
162+ conn .inc_thread_sharing ()
163+
164+ cls .server_thread = cls ._create_server_thread (connections_override )
165+ cls .server_thread .daemon = True
166+ cls .server_thread .start ()
167+ cls .addClassCleanup (cls ._terminate_thread )
168+
169+ # Wait for the live server to be ready
170+ cls .server_thread .is_ready .wait ()
171+ if cls .server_thread .error :
172+ raise cls .server_thread .error
66173
174+ @classmethod
175+ def _create_server_thread (cls , connections_override ):
67176 get_application = partial (
68177 make_application ,
69- static_wrapper = cls .static_wrapper if cls .serve_static else None ,
178+ static_wrapper = cls .static_handler if cls .serve_static else None ,
70179 )
71- cls . _server_process = cls .ProtocolServerProcess (
180+ return cls .server_thread_class (
72181 cls .host ,
73182 get_application ,
183+ connections_override = connections_override ,
184+ port = cls .port ,
74185 setup = set_database_connection ,
75186 )
76- cls . _server_process . start ()
77- while True :
78- if not cls . _server_process . ready . wait ( timeout = 1 ):
79- if cls . _server_process . is_alive ():
80- continue
81- raise RuntimeError ( "Server stopped" ) from None
82- break
83- cls . _port = cls . _server_process . port . value
187+
188+ @ classmethod
189+ def _terminate_thread ( cls ):
190+ # Terminate the live server's thread.
191+ cls . server_thread . terminate ()
192+ # Restore shared connections' non-shareability.
193+ for conn in cls . server_thread . connections_override . values ():
194+ conn . dec_thread_sharing ()
84195
85196 @classmethod
86197 def tearDownClass (cls ):
87- cls ._server_process .terminate ()
88- cls ._server_process .join ()
89- cls ._live_server_modified_settings .disable ()
198+ # The cleanup is now handled by addClassCleanup in _start_server_thread
90199 super ().tearDownClass ()
91200
92201 @classmethod
0 commit comments