1
+ import threading
1
2
from functools import partial
2
3
3
- from daphne .testing import DaphneProcess
4
4
from django .contrib .staticfiles .handlers import ASGIStaticFilesHandler
5
- from django .core .exceptions import ImproperlyConfigured
6
5
from django .db import connections
7
6
from django .db .backends .base .creation import TEST_DATABASE_PREFIX
8
7
from django .test .testcases import TransactionTestCase
9
8
from django .test .utils import modify_settings
9
+ from django .utils .functional import classproperty
10
10
11
11
from channels .routing import get_default_application
12
12
@@ -28,65 +28,172 @@ def set_database_connection():
28
28
settings .DATABASES ["default" ]["NAME" ] = test_db_name
29
29
30
30
31
+ class ChannelsLiveServerThread (threading .Thread ):
32
+ """Thread for running a live ASGI server while the tests are running."""
33
+
34
+ def __init__ (self , host , get_application , connections_override = None , port = 0 , setup = None ):
35
+ self .host = host
36
+ self .port = port
37
+ self .get_application = get_application
38
+ self .connections_override = connections_override
39
+ self .setup = setup
40
+ self .is_ready = threading .Event ()
41
+ self .error = None
42
+ super ().__init__ ()
43
+
44
+ def run (self ):
45
+ """
46
+ Set up the live server and databases, and then loop over handling
47
+ ASGI requests.
48
+ """
49
+ if self .connections_override :
50
+ # Override this thread's database connections with the ones
51
+ # provided by the main thread.
52
+ for alias , conn in self .connections_override .items ():
53
+ connections [alias ] = conn
54
+
55
+ try :
56
+ # Reinstall the reactor for this thread (same as DaphneProcess)
57
+ from daphne .testing import _reinstall_reactor
58
+ _reinstall_reactor ()
59
+
60
+ from twisted .internet import reactor
61
+ from daphne .endpoints import build_endpoint_description_strings
62
+ from daphne .server import Server
63
+
64
+ # Get the application
65
+ application = self .get_application ()
66
+
67
+ # Create the server
68
+ endpoints = build_endpoint_description_strings (
69
+ host = self .host , port = self .port
70
+ )
71
+ self .server = Server (
72
+ application = application ,
73
+ endpoints = endpoints ,
74
+ signal_handlers = False ,
75
+ ready_callable = self ._set_ready ,
76
+ verbosity = 0 ,
77
+ )
78
+
79
+ # Run setup if provided
80
+ if self .setup is not None :
81
+ self .setup ()
82
+
83
+ # Start the server
84
+ self .server .run ()
85
+ except Exception as e :
86
+ self .error = e
87
+ self .is_ready .set ()
88
+ finally :
89
+ connections .close_all ()
90
+
91
+ def _set_ready (self ):
92
+ """Called by Daphne when the server is ready."""
93
+ if self .server .listening_addresses :
94
+ self .port = self .server .listening_addresses [0 ][1 ]
95
+ self .is_ready .set ()
96
+
97
+ def terminate (self ):
98
+ if hasattr (self , "server" ):
99
+ # Stop the ASGI server
100
+ from twisted .internet import reactor
101
+
102
+ if reactor .running :
103
+ reactor .callFromThread (reactor .stop )
104
+ self .join (timeout = 5 )
105
+
106
+
31
107
class ChannelsLiveServerTestCase (TransactionTestCase ):
32
108
"""
33
- Does basically the same as TransactionTestCase but also launches a
34
- live Daphne server in a separate process, so
35
- that the tests may use another test framework, such as Selenium,
36
- instead of the built-in dummy client.
109
+ Do basically the same as TransactionTestCase but also launch a live ASGI
110
+ server in a separate thread so that the tests may use another testing
111
+ framework, such as Selenium for example, instead of the built-in dummy
112
+ client.
113
+ It inherits from TransactionTestCase instead of TestCase because the
114
+ threads don't share the same transactions (unless if using in-memory
115
+ sqlite) and each thread needs to commit all their transactions so that the
116
+ other thread can see the changes.
37
117
"""
38
118
39
119
host = "localhost"
40
- ProtocolServerProcess = DaphneProcess
41
- static_wrapper = ASGIStaticFilesHandler
120
+ port = 0
121
+ server_thread_class = ChannelsLiveServerThread
122
+ static_handler = ASGIStaticFilesHandler
42
123
serve_static = True
43
124
44
- @property
45
- def live_server_url (self ):
46
- return "http://%s:%s" % (self .host , self . _port )
125
+ @classproperty
126
+ def live_server_url (cls ):
127
+ return "http://%s:%s" % (cls .host , cls . server_thread . port )
47
128
48
- @property
49
- def live_server_ws_url (self ):
50
- return "ws://%s:%s" % (self .host , self ._port )
129
+ @classproperty
130
+ def live_server_ws_url (cls ):
131
+ return "ws://%s:%s" % (cls .host , cls .server_thread .port )
132
+
133
+ @classproperty
134
+ def allowed_host (cls ):
135
+ return cls .host
51
136
52
137
@classmethod
53
- def setUpClass (cls ):
54
- for connection in connections .all ():
55
- if cls ._is_in_memory_db (connection ):
56
- raise ImproperlyConfigured (
57
- "ChannelLiveServerTestCase can not be used with in memory databases"
58
- )
138
+ def _make_connections_override (cls ):
139
+ connections_override = {}
140
+ for conn in connections .all ():
141
+ # If using in-memory sqlite databases, pass the connections to
142
+ # the server thread.
143
+ if conn .vendor == "sqlite" and conn .is_in_memory_db ():
144
+ connections_override [conn .alias ] = conn
145
+ return connections_override
59
146
147
+ @classmethod
148
+ def setUpClass (cls ):
60
149
super ().setUpClass ()
61
-
62
- cls ._live_server_modified_settings = modify_settings (
63
- ALLOWED_HOSTS = {"append" : cls .host }
150
+ cls .enterClassContext (
151
+ modify_settings (ALLOWED_HOSTS = {"append" : cls .allowed_host })
64
152
)
65
- cls ._live_server_modified_settings .enable ()
153
+ cls ._start_server_thread ()
154
+
155
+ @classmethod
156
+ def _start_server_thread (cls ):
157
+ connections_override = cls ._make_connections_override ()
158
+ for conn in connections_override .values ():
159
+ # Explicitly enable thread-shareability for this connection.
160
+ conn .inc_thread_sharing ()
161
+
162
+ cls .server_thread = cls ._create_server_thread (connections_override )
163
+ cls .server_thread .daemon = True
164
+ cls .server_thread .start ()
165
+ cls .addClassCleanup (cls ._terminate_thread )
166
+
167
+ # Wait for the live server to be ready
168
+ cls .server_thread .is_ready .wait ()
169
+ if cls .server_thread .error :
170
+ raise cls .server_thread .error
66
171
172
+ @classmethod
173
+ def _create_server_thread (cls , connections_override ):
67
174
get_application = partial (
68
175
make_application ,
69
- static_wrapper = cls .static_wrapper if cls .serve_static else None ,
176
+ static_wrapper = cls .static_handler if cls .serve_static else None ,
70
177
)
71
- cls . _server_process = cls .ProtocolServerProcess (
178
+ return cls .server_thread_class (
72
179
cls .host ,
73
180
get_application ,
181
+ connections_override = connections_override ,
182
+ port = cls .port ,
74
183
setup = set_database_connection ,
75
184
)
76
- cls . _server_process . start ()
77
- while True :
78
- if not cls . _server_process . ready . wait ( timeout = 1 ):
79
- if cls . _server_process . is_alive ():
80
- continue
81
- raise RuntimeError ( "Server stopped" ) from None
82
- break
83
- cls . _port = cls . _server_process . port . value
185
+
186
+ @ classmethod
187
+ def _terminate_thread ( cls ):
188
+ # Terminate the live server's thread.
189
+ cls . server_thread . terminate ()
190
+ # Restore shared connections' non-shareability.
191
+ for conn in cls . server_thread . connections_override . values ():
192
+ conn . dec_thread_sharing ()
84
193
85
194
@classmethod
86
195
def tearDownClass (cls ):
87
- cls ._server_process .terminate ()
88
- cls ._server_process .join ()
89
- cls ._live_server_modified_settings .disable ()
196
+ # The cleanup is now handled by addClassCleanup in _start_server_thread
90
197
super ().tearDownClass ()
91
198
92
199
@classmethod
0 commit comments