Skip to content

Commit 786d827

Browse files
authored
improve thread safety (#1)
1 parent ce902ba commit 786d827

File tree

1 file changed

+55
-44
lines changed

1 file changed

+55
-44
lines changed

hivemind_http_protocol/__init__.py

Lines changed: 55 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
import os
44
import os.path
55
import random
6-
import threading
6+
from collections import defaultdict
7+
from queue import Queue
78
from os import makedirs
89
from os.path import exists, join
910
from socket import gethostname
10-
from typing import Dict, Any, Optional, Tuple, List
11-
from collections import defaultdict
11+
from typing import Dict, Any, Optional, Tuple, Union
12+
1213
import pybase64
1314
from OpenSSL import crypto
1415
from ovos_bus_client.session import Session
@@ -28,11 +29,6 @@
2829
from hivemind_plugin_manager.protocols import NetworkProtocol
2930
from poorman_handshake import PasswordHandShake
3031

31-
_LOCK = threading.RLock()
32-
CLIENTS: Dict[str, HiveMindClientConnection] = {}
33-
UNDELIVERED: Dict[str, List[str]] = defaultdict(list) # key: [messages]
34-
UNDELIVERED_BIN: Dict[str, List[str]] = defaultdict(list) # key: [b64_messages]
35-
3632

3733
@dataclasses.dataclass
3834
class HiveMindHttpProtocol(NetworkProtocol):
@@ -46,7 +42,6 @@ class HiveMindHttpProtocol(NetworkProtocol):
4642
hm_protocol: Optional[HiveMindListenerProtocol] = None
4743
callbacks: ClientCallbacks = dataclasses.field(default_factory=ClientCallbacks)
4844

49-
5045
def run(self):
5146
LOG.debug(f"HTTP server config: {self.config}")
5247
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
@@ -70,10 +65,10 @@ def run(self):
7065
cert_file = f"{cert_dir}/{cert_name}.crt"
7166
key_file = f"{cert_dir}/{cert_name}.key"
7267
if not os.path.isfile(key_file):
73-
LOG.info(f"generating self-signed SSL certificate")
68+
LOG.info(f"Generating self-signed SSL certificate")
7469
cert_file, key_file = self.create_self_signed_cert(cert_dir, cert_name)
75-
LOG.debug("using ssl key at " + key_file)
76-
LOG.debug("using ssl certificate at " + cert_file)
70+
LOG.debug("Using SSL key at " + key_file)
71+
LOG.debug("Using SSL certificate at " + cert_file)
7772
ssl_options = {"certfile": cert_file, "keyfile": key_file}
7873
LOG.info(f"HTTPS listener started at port: {port}")
7974
application.listen(port, host, ssl_options=ssl_options)
@@ -105,7 +100,7 @@ def create_self_signed_cert(
105100
makedirs(cert_dir, exist_ok=True)
106101

107102
if not exists(join(cert_dir, cert_file)) or not exists(join(cert_dir, key_file)):
108-
# create a key pair
103+
# Create a key pair
109104
k = crypto.PKey()
110105
k.generate_key(crypto.TYPE_RSA, 2048)
111106

@@ -135,6 +130,11 @@ class HiveMindHttpHandler(web.RequestHandler):
135130
"""Base handler for HTTP requests."""
136131
hm_protocol = None
137132

133+
# Class-level properties for managing client state and message queues
134+
clients: Dict[str, HiveMindClientConnection] = {}
135+
undelivered: Dict[str, Queue] = defaultdict(Queue) # Non-binary messages
136+
undelivered_bin: Dict[str, Queue] = defaultdict(Queue) # Binary messages
137+
138138
def decode_auth(self):
139139
auth = self.get_argument("authorization", "")
140140
if not auth:
@@ -144,26 +144,22 @@ def decode_auth(self):
144144
userpass_decoded = pybase64.b64decode(userpass_encoded).decode("utf-8")
145145
return userpass_decoded.split(":")
146146

147-
def get_client(self, useragent, key, cache = True) -> Optional[HiveMindClientConnection]:
148-
global CLIENTS, UNDELIVERED
149-
150-
if cache and key in CLIENTS:
151-
return CLIENTS[key]
147+
def get_client(self, useragent, key, cache=True) -> Optional[HiveMindClientConnection]:
148+
if cache and key in self.clients:
149+
return self.clients[key]
152150

153-
def do_send(payload: str, is_bin: bool):
154-
with _LOCK:
155-
if is_bin:
156-
payload = pybase64.b64encode(payload).decode("utf-8")
157-
UNDELIVERED_BIN[key].append(payload)
158-
else:
159-
UNDELIVERED[key].append(payload)
151+
def do_send(payload: Union[bytes, str], is_bin: bool):
152+
if is_bin:
153+
payload = pybase64.b64encode(payload).decode("utf-8")
154+
self.undelivered_bin[key].put(payload)
155+
else:
156+
self.undelivered[key].put(payload)
160157

161158
def do_disconnect():
162-
with _LOCK:
163-
if key in UNDELIVERED:
164-
UNDELIVERED.pop(key)
165-
if key in CLIENTS:
166-
CLIENTS.pop(key)
159+
if key in self.undelivered:
160+
self.undelivered.pop(key)
161+
if key in self.clients:
162+
self.clients.pop(key)
167163

168164
client = HiveMindClientConnection(
169165
key=key,
@@ -176,7 +172,7 @@ def do_disconnect():
176172
self.hm_protocol.db.sync()
177173
user = self.hm_protocol.db.get_client_by_api_key(key)
178174
if not user:
179-
LOG.error("Client provided an invalid api key")
175+
LOG.error("Client provided an invalid Access key")
180176
self.hm_protocol.handle_invalid_key_connected(client)
181177
return None
182178

@@ -195,7 +191,7 @@ def do_disconnect():
195191

196192
client.node_type = HiveMindNodeType.NODE # TODO . placeholder
197193
if cache:
198-
CLIENTS[key] = client
194+
self.clients[key] = client
199195
return client
200196

201197

@@ -232,18 +228,16 @@ async def post(self):
232228

233229
class DisconnectHandler(HiveMindHttpHandler):
234230
async def post(self):
235-
global CLIENTS
236231

237232
try:
238233
useragent, key = self.decode_auth()
239234
if not key:
240235
self.write({"error": "Missing authorization"})
241236
return
242-
if key in CLIENTS:
237+
if key in HiveMindHttpHandler.clients:
243238
client = self.get_client(useragent, key)
244239
LOG.info(f"disconnecting client: {client.peer}")
245240
self.hm_protocol.handle_client_disconnected(client)
246-
CLIENTS.pop(key)
247241
self.write({"status": "Disconnected"})
248242
else:
249243
self.write({"error": "Already Disconnected"})
@@ -261,7 +255,7 @@ async def post(self):
261255
self.write({"error": "Missing authorization"})
262256
return
263257
# refuse if connect wasnt called first
264-
if key not in CLIENTS:
258+
if key not in HiveMindHttpHandler.clients:
265259
self.write({"error": "Client is not connected"})
266260
return
267261

@@ -300,13 +294,21 @@ async def get(self):
300294
return
301295

302296
# refuse if connect wasnt called first
303-
if key not in CLIENTS:
297+
if key not in HiveMindHttpHandler.clients:
304298
self.write({"error": "Client is not connected"})
305299
return
306300

307-
# send non-binary payloads to the client
308-
messages = UNDELIVERED[key]
309-
UNDELIVERED[key] = []
301+
messages = []
302+
queue = HiveMindHttpHandler.undelivered[key]
303+
304+
# Retrieve all messages from the queue
305+
while not queue.empty():
306+
try:
307+
message = queue.get_nowait()
308+
messages.append(message)
309+
except Exception as e:
310+
# Handle unexpected errors (unlikely with get_nowait)
311+
break
310312
self.write({"status": "messages retrieved", "messages": messages})
311313
except Exception as e:
312314
LOG.error(f"Retrieving messages failed: {e}")
@@ -324,13 +326,22 @@ async def get(self):
324326
return
325327

326328
# refuse if connect wasnt called first
327-
if key not in CLIENTS:
329+
if key not in HiveMindHttpHandler.clients:
328330
self.write({"error": "Client is not connected"})
329331
return
330332

331-
# send non-binary payloads to the client
332-
messages = UNDELIVERED_BIN[key]
333-
UNDELIVERED_BIN[key] = []
333+
messages = []
334+
queue = HiveMindHttpHandler.undelivered_bin[key]
335+
336+
# Retrieve all messages from the queue
337+
while not queue.empty():
338+
try:
339+
message = queue.get_nowait()
340+
messages.append(message)
341+
except Exception as e:
342+
# Handle unexpected errors (unlikely with get_nowait)
343+
break
344+
334345
self.write({"status": "messages retrieved", "b64_messages": messages})
335346
except Exception as e:
336347
LOG.error(f"Retrieving messages failed: {e}")

0 commit comments

Comments
 (0)