1
+ import threading
1
2
from functools import partial
2
3
3
- from daphne .testing import DaphneProcess
4
4
from django .contrib .staticfiles .handlers import ASGIStaticFilesHandler
5
- from django .core .exceptions import ImproperlyConfigured
6
5
from django .db import connections
7
6
from django .db .backends .base .creation import TEST_DATABASE_PREFIX
8
7
from django .test .testcases import TransactionTestCase
9
8
from django .test .utils import modify_settings
9
+ from django .utils .functional import classproperty
10
10
11
11
from channels .routing import get_default_application
12
12
@@ -28,65 +28,174 @@ def set_database_connection():
28
28
settings .DATABASES ["default" ]["NAME" ] = test_db_name
29
29
30
30
31
+ class ChannelsLiveServerThread (threading .Thread ):
32
+ """Thread for running a live ASGI server while the tests are running."""
33
+
34
+ def __init__ (
35
+ self , host , get_application , connections_override = None , port = 0 , setup = None
36
+ ):
37
+ self .host = host
38
+ self .port = port
39
+ self .get_application = get_application
40
+ self .connections_override = connections_override
41
+ self .setup = setup
42
+ self .is_ready = threading .Event ()
43
+ self .error = None
44
+ super ().__init__ ()
45
+
46
+ def run (self ):
47
+ """
48
+ Set up the live server and databases, and then loop over handling
49
+ ASGI requests.
50
+ """
51
+ if self .connections_override :
52
+ # Override this thread's database connections with the ones
53
+ # provided by the main thread.
54
+ for alias , conn in self .connections_override .items ():
55
+ connections [alias ] = conn
56
+
57
+ try :
58
+ # Reinstall the reactor for this thread (same as DaphneProcess)
59
+ from daphne .testing import _reinstall_reactor
60
+
61
+ _reinstall_reactor ()
62
+
63
+ from daphne .endpoints import build_endpoint_description_strings
64
+ from daphne .server import Server
65
+
66
+ # Get the application
67
+ application = self .get_application ()
68
+
69
+ # Create the server
70
+ endpoints = build_endpoint_description_strings (
71
+ host = self .host , port = self .port
72
+ )
73
+ self .server = Server (
74
+ application = application ,
75
+ endpoints = endpoints ,
76
+ signal_handlers = False ,
77
+ ready_callable = self ._set_ready ,
78
+ verbosity = 0 ,
79
+ )
80
+
81
+ # Run setup if provided
82
+ if self .setup is not None :
83
+ self .setup ()
84
+
85
+ # Start the server
86
+ self .server .run ()
87
+ except Exception as e :
88
+ self .error = e
89
+ self .is_ready .set ()
90
+ finally :
91
+ connections .close_all ()
92
+
93
+ def _set_ready (self ):
94
+ """Called by Daphne when the server is ready."""
95
+ if self .server .listening_addresses :
96
+ self .port = self .server .listening_addresses [0 ][1 ]
97
+ self .is_ready .set ()
98
+
99
+ def terminate (self ):
100
+ if hasattr (self , "server" ):
101
+ # Stop the ASGI server
102
+ from twisted .internet import reactor
103
+
104
+ if reactor .running :
105
+ reactor .callFromThread (reactor .stop )
106
+ self .join (timeout = 5 )
107
+
108
+
31
109
class ChannelsLiveServerTestCase (TransactionTestCase ):
32
110
"""
33
- Does basically the same as TransactionTestCase but also launches a
34
- live Daphne server in a separate process, so
35
- that the tests may use another test framework, such as Selenium,
36
- instead of the built-in dummy client.
111
+ Do basically the same as TransactionTestCase but also launch a live ASGI
112
+ server in a separate thread so that the tests may use another testing
113
+ framework, such as Selenium for example, instead of the built-in dummy
114
+ client.
115
+ It inherits from TransactionTestCase instead of TestCase because the
116
+ threads don't share the same transactions (unless if using in-memory
117
+ sqlite) and each thread needs to commit all their transactions so that the
118
+ other thread can see the changes.
37
119
"""
38
120
39
121
host = "localhost"
40
- ProtocolServerProcess = DaphneProcess
41
- static_wrapper = ASGIStaticFilesHandler
122
+ port = 0
123
+ server_thread_class = ChannelsLiveServerThread
124
+ static_handler = ASGIStaticFilesHandler
42
125
serve_static = True
43
126
44
- @property
45
- def live_server_url (self ):
46
- return "http://%s:%s" % (self .host , self . _port )
127
+ @classproperty
128
+ def live_server_url (cls ):
129
+ return "http://%s:%s" % (cls .host , cls . server_thread . port )
47
130
48
- @property
49
- def live_server_ws_url (self ):
50
- return "ws://%s:%s" % (self .host , self ._port )
131
+ @classproperty
132
+ def live_server_ws_url (cls ):
133
+ return "ws://%s:%s" % (cls .host , cls .server_thread .port )
134
+
135
+ @classproperty
136
+ def allowed_host (cls ):
137
+ return cls .host
51
138
52
139
@classmethod
53
- def setUpClass (cls ):
54
- for connection in connections .all ():
55
- if cls ._is_in_memory_db (connection ):
56
- raise ImproperlyConfigured (
57
- "ChannelLiveServerTestCase can not be used with in memory databases"
58
- )
140
+ def _make_connections_override (cls ):
141
+ connections_override = {}
142
+ for conn in connections .all ():
143
+ # If using in-memory sqlite databases, pass the connections to
144
+ # the server thread.
145
+ if conn .vendor == "sqlite" and conn .is_in_memory_db ():
146
+ connections_override [conn .alias ] = conn
147
+ return connections_override
59
148
149
+ @classmethod
150
+ def setUpClass (cls ):
60
151
super ().setUpClass ()
61
-
62
- cls ._live_server_modified_settings = modify_settings (
63
- ALLOWED_HOSTS = {"append" : cls .host }
152
+ cls .enterClassContext (
153
+ modify_settings (ALLOWED_HOSTS = {"append" : cls .allowed_host })
64
154
)
65
- cls ._live_server_modified_settings .enable ()
155
+ cls ._start_server_thread ()
156
+
157
+ @classmethod
158
+ def _start_server_thread (cls ):
159
+ connections_override = cls ._make_connections_override ()
160
+ for conn in connections_override .values ():
161
+ # Explicitly enable thread-shareability for this connection.
162
+ conn .inc_thread_sharing ()
163
+
164
+ cls .server_thread = cls ._create_server_thread (connections_override )
165
+ cls .server_thread .daemon = True
166
+ cls .server_thread .start ()
167
+ cls .addClassCleanup (cls ._terminate_thread )
168
+
169
+ # Wait for the live server to be ready
170
+ cls .server_thread .is_ready .wait ()
171
+ if cls .server_thread .error :
172
+ raise cls .server_thread .error
66
173
174
+ @classmethod
175
+ def _create_server_thread (cls , connections_override ):
67
176
get_application = partial (
68
177
make_application ,
69
- static_wrapper = cls .static_wrapper if cls .serve_static else None ,
178
+ static_wrapper = cls .static_handler if cls .serve_static else None ,
70
179
)
71
- cls . _server_process = cls .ProtocolServerProcess (
180
+ return cls .server_thread_class (
72
181
cls .host ,
73
182
get_application ,
183
+ connections_override = connections_override ,
184
+ port = cls .port ,
74
185
setup = set_database_connection ,
75
186
)
76
- cls . _server_process . start ()
77
- while True :
78
- if not cls . _server_process . ready . wait ( timeout = 1 ):
79
- if cls . _server_process . is_alive ():
80
- continue
81
- raise RuntimeError ( "Server stopped" ) from None
82
- break
83
- cls . _port = cls . _server_process . port . value
187
+
188
+ @ classmethod
189
+ def _terminate_thread ( cls ):
190
+ # Terminate the live server's thread.
191
+ cls . server_thread . terminate ()
192
+ # Restore shared connections' non-shareability.
193
+ for conn in cls . server_thread . connections_override . values ():
194
+ conn . dec_thread_sharing ()
84
195
85
196
@classmethod
86
197
def tearDownClass (cls ):
87
- cls ._server_process .terminate ()
88
- cls ._server_process .join ()
89
- cls ._live_server_modified_settings .disable ()
198
+ # The cleanup is now handled by addClassCleanup in _start_server_thread
90
199
super ().tearDownClass ()
91
200
92
201
@classmethod
0 commit comments