Skip to content

Commit 832a493

Browse files
committed
Add threading to host-side comms server
1 parent 45cdc31 commit 832a493

File tree

2 files changed

+77
-56
lines changed

2 files changed

+77
-56
lines changed

lgl_android_install.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -606,19 +606,21 @@ def configure_server(conn: ADBConnect,
606606
profile_dir: The desired output directory path for timeline. Existing
607607
files in the directory may be overwritten.
608608
'''
609+
verbose = False
610+
609611
# Create a server instance
610-
instance = server.CommsServer(0)
612+
instance = server.CommsServer(0, verbose)
611613

612614
if timeline_file:
613615
# Import late to avoid pulling in transitive deps when unused
614616
from lglpy.comms import service_gpu_timeline
615-
service_tl = service_gpu_timeline.GPUTimelineService(timeline_file)
617+
service_tl = service_gpu_timeline.GPUTimelineService(timeline_file, verbose)
616618
instance.register_endpoint(service_tl)
617619

618620
if profile_dir:
619621
# Import late to avoid pulling in transitive deps when unused
620622
from lglpy.comms import service_gpu_profile
621-
service_prof = service_gpu_profile.GPUProfileService(profile_dir)
623+
service_prof = service_gpu_profile.GPUProfileService(profile_dir, verbose)
622624
instance.register_endpoint(service_prof)
623625

624626
# Start it running

lglpy/comms/server.py

Lines changed: 72 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,23 @@
2222
# -----------------------------------------------------------------------------
2323

2424
'''
25-
This module implements the server-side communications module that can accept
26-
client connections from a layer driver, and dispatch messages to registered
27-
service handler in the server.
28-
29-
This module currently only accepts a single connection at a time and message
30-
handling is synchronous inside the server. It is therefore not possible to
31-
implement pseudo-host-driven event loops if the layer is using multiple
32-
services concurrently - this needs threads per service.
25+
This module implements the server-side of the communications module that can
26+
accept connections from client layer drivers running on the device. The
27+
protocol is service-based, and the server will dispatch messages to the
28+
registered service handler for each message channel.
29+
30+
The server is multi-threaded, allowing multiple layers to concurrently access
31+
networked services provided by host-side implementations. However, within each
32+
client connection messages are handled synchronously by a single worker thread.
33+
It is therefore not possible to implement pseudo-host-driven event loops if a
34+
layer is using multiple services concurrently - this needs threads per service
35+
endpoint which is not yet implemented.
3336
'''
3437

3538
import enum
3639
import socket
3740
import struct
41+
import threading
3842
from typing import Any, Optional
3943

4044

@@ -143,7 +147,7 @@ class CommsServer:
143147
Class listening for client connection from a layer and handling messages.
144148
145149
This implementation is designed to run in a thread, so has a run method
146-
that will setup and listen on the server socket.q
150+
that will setup and listen on the server socket.
147151
148152
This implementation only handles a single layer connection at a time, but
149153
can handle multiple connections serially without restarting.
@@ -173,7 +177,6 @@ def __init__(self, port: int, verbose: bool = False):
173177
self.register_endpoint(self)
174178

175179
self.shutdown = False
176-
self.sockd = None # type: Optional[socket.socket]
177180

178181
self.sockl = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
179182
self.sockl.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@@ -185,6 +188,9 @@ def __init__(self, port: int, verbose: bool = False):
185188
# Work out which port was assigned if not user-defined
186189
self.port = self.sockl.getsockname()[1]
187190

191+
# Pool of worker threads
192+
self.workers: list[threading.Thread] = []
193+
188194
def register_endpoint(self, endpoint: Any) -> int:
189195
'''
190196
Register a new service endpoint with the server.
@@ -235,55 +241,55 @@ def run(self) -> None:
235241
if self.verbose:
236242
print('Waiting for client connection')
237243
try:
238-
self.sockd, _ = self.sockl.accept()
244+
sockd, _ = self.sockl.accept()
239245
if self.verbose:
240246
print(' + Client connected')
241247

242-
self.run_client()
243-
244-
if self.verbose:
245-
print(' + Client disconnected')
246-
self.sockd.close()
247-
self.sockd = None
248-
249-
except ClientDropped:
250-
if self.verbose:
251-
print(' + Client disconnected')
252-
if self.sockd:
253-
self.sockd.close()
254-
self.sockd = None
248+
thread = threading.Thread(target=self.run_client, args=(sockd,))
249+
self.workers.append(thread)
250+
thread.start()
255251

256252
except OSError:
257253
continue
258254

259255
self.sockl.close()
260256

261-
def run_client(self) -> None:
257+
def run_client(self, sockd: socket.socket) -> None:
262258
'''
263259
Enter client message handler run loop.
264260
265261
Raises:
266262
ClientDropped: The client disconnected from the socket.
267263
'''
268-
while not self.shutdown:
269-
# Read the header
270-
data = self.receive_data(Message.HEADER_LEN)
271-
message = Message(data)
272-
273-
# Read the payload if there is one
274-
if message.payload_size:
275-
data = self.receive_data(message.payload_size)
276-
message.add_payload(data)
277-
278-
# Dispatch to a service handler
279-
endpoint = self.endpoints[message.endpoint_id]
280-
response = endpoint.handle_message(message)
264+
try:
265+
while not self.shutdown:
266+
# Read the header
267+
data = self.receive_data(sockd, Message.HEADER_LEN)
268+
message = Message(data)
269+
270+
# Read the payload if there is one
271+
if message.payload_size:
272+
data = self.receive_data(sockd, message.payload_size)
273+
message.add_payload(data)
274+
275+
# Dispatch to a service handler
276+
endpoint = self.endpoints[message.endpoint_id]
277+
response = endpoint.handle_message(message)
278+
279+
# Send a response for all TX_RX messages
280+
if message.message_type == MessageType.TX_RX:
281+
header = Response(message, response)
282+
self.send_data(sockd, header.get_header())
283+
self.send_data(sockd, response)
284+
285+
except ClientDropped:
286+
pass
287+
288+
finally:
289+
if self.verbose:
290+
print(' + Client disconnected')
281291

282-
# Send a response for all TX_RX messages
283-
if message.message_type == MessageType.TX_RX:
284-
header = Response(message, response)
285-
self.send_data(header.get_header())
286-
self.send_data(response)
292+
sockd.close()
287293

288294
def stop(self) -> None:
289295
'''
@@ -294,14 +300,29 @@ def stop(self) -> None:
294300
if self.sockl is not None:
295301
self.sockl.close()
296302

297-
if self.sockd is not None:
298-
self.sockd.shutdown(socket.SHUT_RDWR)
303+
self.wait_for_workers()
299304

300-
def receive_data(self, size: int) -> bytes:
305+
def wait_for_workers(self) -> None:
306+
'''
307+
Wait for workers
308+
'''
309+
# Wait for each worker - slightly convoluted logic as workers threads
310+
# remove themselves from the list as they complete so there is a race
311+
# between self.workers test and testing a worker in the loop
312+
while self.workers:
313+
try:
314+
worker = self.workers[0]
315+
worker.join()
316+
except IndexError:
317+
pass
318+
319+
@staticmethod
320+
def receive_data(sockd: socket.socket, size: int) -> bytes:
301321
'''
302322
Fetch a fixed size packet from the socket.
303323
304324
Args:
325+
sockd: The data socket.
305326
size: The length of the packet in bytes.
306327
307328
Returns:
@@ -310,31 +331,29 @@ def receive_data(self, size: int) -> bytes:
310331
Raises:
311332
ClientDropped: The client disconnected from the socket.
312333
'''
313-
assert self.sockd is not None
314-
315334
data = b''
316335
while len(data) < size:
317-
new_data = self.sockd.recv(size - len(data))
336+
new_data = sockd.recv(size - len(data))
318337
if not new_data:
319338
raise ClientDropped()
320339
data = data + new_data
321340

322341
return data
323342

324-
def send_data(self, data: bytes) -> None:
343+
@staticmethod
344+
def send_data(sockd: socket.socket, data: bytes) -> None:
325345
'''
326346
Send a fixed size packet to the socket.
327347
328348
Args:
349+
sockd: The data socket.
329350
data: The binary data to send.
330351
331352
Raises:
332353
ClientDropped: The client disconnected from the socket.
333354
'''
334-
assert self.sockd is not None
335-
336355
while len(data):
337-
sent_bytes = self.sockd.send(data)
356+
sent_bytes = sockd.send(data)
338357
if not sent_bytes:
339358
raise ClientDropped()
340359
data = data[sent_bytes:]

0 commit comments

Comments
 (0)