33import os
44import os .path
55import random
6- import threading
6+ from collections import defaultdict
7+ from queue import Queue
78from os import makedirs
89from os .path import exists , join
910from socket import gethostname
10- from typing import Dict , Any , Optional , Tuple , List
11- from collections import defaultdict
11+ from typing import Dict , Any , Optional , Tuple , Union
12+
1213import pybase64
1314from OpenSSL import crypto
1415from ovos_bus_client .session import Session
2829from hivemind_plugin_manager .protocols import NetworkProtocol
2930from poorman_handshake import PasswordHandShake
3031
31- _LOCK = threading .RLock ()
32- CLIENTS : Dict [str , HiveMindClientConnection ] = {}
33- UNDELIVERED : Dict [str , List [str ]] = defaultdict (list ) # key: [messages]
34- UNDELIVERED_BIN : Dict [str , List [str ]] = defaultdict (list ) # key: [b64_messages]
35-
3632
3733@dataclasses .dataclass
3834class HiveMindHttpProtocol (NetworkProtocol ):
@@ -46,7 +42,6 @@ class HiveMindHttpProtocol(NetworkProtocol):
4642 hm_protocol : Optional [HiveMindListenerProtocol ] = None
4743 callbacks : ClientCallbacks = dataclasses .field (default_factory = ClientCallbacks )
4844
49-
5045 def run (self ):
5146 LOG .debug (f"HTTP server config: { self .config } " )
5247 asyncio .set_event_loop_policy (AnyThreadEventLoopPolicy ())
@@ -70,10 +65,10 @@ def run(self):
7065 cert_file = f"{ cert_dir } /{ cert_name } .crt"
7166 key_file = f"{ cert_dir } /{ cert_name } .key"
7267 if not os .path .isfile (key_file ):
73- LOG .info (f"generating self-signed SSL certificate" )
68+ LOG .info (f"Generating self-signed SSL certificate" )
7469 cert_file , key_file = self .create_self_signed_cert (cert_dir , cert_name )
75- LOG .debug ("using ssl key at " + key_file )
76- LOG .debug ("using ssl certificate at " + cert_file )
70+ LOG .debug ("Using SSL key at " + key_file )
71+ LOG .debug ("Using SSL certificate at " + cert_file )
7772 ssl_options = {"certfile" : cert_file , "keyfile" : key_file }
7873 LOG .info (f"HTTPS listener started at port: { port } " )
7974 application .listen (port , host , ssl_options = ssl_options )
@@ -105,7 +100,7 @@ def create_self_signed_cert(
105100 makedirs (cert_dir , exist_ok = True )
106101
107102 if not exists (join (cert_dir , cert_file )) or not exists (join (cert_dir , key_file )):
108- # create a key pair
103+ # Create a key pair
109104 k = crypto .PKey ()
110105 k .generate_key (crypto .TYPE_RSA , 2048 )
111106
@@ -135,6 +130,11 @@ class HiveMindHttpHandler(web.RequestHandler):
135130 """Base handler for HTTP requests."""
136131 hm_protocol = None
137132
133+ # Class-level properties for managing client state and message queues
134+ clients : Dict [str , HiveMindClientConnection ] = {}
135+ undelivered : Dict [str , Queue ] = defaultdict (Queue ) # Non-binary messages
136+ undelivered_bin : Dict [str , Queue ] = defaultdict (Queue ) # Binary messages
137+
138138 def decode_auth (self ):
139139 auth = self .get_argument ("authorization" , "" )
140140 if not auth :
@@ -144,26 +144,22 @@ def decode_auth(self):
144144 userpass_decoded = pybase64 .b64decode (userpass_encoded ).decode ("utf-8" )
145145 return userpass_decoded .split (":" )
146146
147- def get_client (self , useragent , key , cache = True ) -> Optional [HiveMindClientConnection ]:
148- global CLIENTS , UNDELIVERED
149-
150- if cache and key in CLIENTS :
151- return CLIENTS [key ]
147+ def get_client (self , useragent , key , cache = True ) -> Optional [HiveMindClientConnection ]:
148+ if cache and key in self .clients :
149+ return self .clients [key ]
152150
153- def do_send (payload : str , is_bin : bool ):
154- with _LOCK :
155- if is_bin :
156- payload = pybase64 .b64encode (payload ).decode ("utf-8" )
157- UNDELIVERED_BIN [key ].append (payload )
158- else :
159- UNDELIVERED [key ].append (payload )
151+ def do_send (payload : Union [bytes , str ], is_bin : bool ):
152+ if is_bin :
153+ payload = pybase64 .b64encode (payload ).decode ("utf-8" )
154+ self .undelivered_bin [key ].put (payload )
155+ else :
156+ self .undelivered [key ].put (payload )
160157
161158 def do_disconnect ():
162- with _LOCK :
163- if key in UNDELIVERED :
164- UNDELIVERED .pop (key )
165- if key in CLIENTS :
166- CLIENTS .pop (key )
159+ if key in self .undelivered :
160+ self .undelivered .pop (key )
161+ if key in self .clients :
162+ self .clients .pop (key )
167163
168164 client = HiveMindClientConnection (
169165 key = key ,
@@ -176,7 +172,7 @@ def do_disconnect():
176172 self .hm_protocol .db .sync ()
177173 user = self .hm_protocol .db .get_client_by_api_key (key )
178174 if not user :
179- LOG .error ("Client provided an invalid api key" )
175+ LOG .error ("Client provided an invalid Access key" )
180176 self .hm_protocol .handle_invalid_key_connected (client )
181177 return None
182178
@@ -195,7 +191,7 @@ def do_disconnect():
195191
196192 client .node_type = HiveMindNodeType .NODE # TODO . placeholder
197193 if cache :
198- CLIENTS [key ] = client
194+ self . clients [key ] = client
199195 return client
200196
201197
@@ -232,18 +228,16 @@ async def post(self):
232228
233229class DisconnectHandler (HiveMindHttpHandler ):
234230 async def post (self ):
235- global CLIENTS
236231
237232 try :
238233 useragent , key = self .decode_auth ()
239234 if not key :
240235 self .write ({"error" : "Missing authorization" })
241236 return
242- if key in CLIENTS :
237+ if key in HiveMindHttpHandler . clients :
243238 client = self .get_client (useragent , key )
244239 LOG .info (f"disconnecting client: { client .peer } " )
245240 self .hm_protocol .handle_client_disconnected (client )
246- CLIENTS .pop (key )
247241 self .write ({"status" : "Disconnected" })
248242 else :
249243 self .write ({"error" : "Already Disconnected" })
@@ -261,7 +255,7 @@ async def post(self):
261255 self .write ({"error" : "Missing authorization" })
262256 return
263257 # refuse if connect wasnt called first
264- if key not in CLIENTS :
258+ if key not in HiveMindHttpHandler . clients :
265259 self .write ({"error" : "Client is not connected" })
266260 return
267261
@@ -300,13 +294,21 @@ async def get(self):
300294 return
301295
302296 # refuse if connect wasnt called first
303- if key not in CLIENTS :
297+ if key not in HiveMindHttpHandler . clients :
304298 self .write ({"error" : "Client is not connected" })
305299 return
306300
307- # send non-binary payloads to the client
308- messages = UNDELIVERED [key ]
309- UNDELIVERED [key ] = []
301+ messages = []
302+ queue = HiveMindHttpHandler .undelivered [key ]
303+
304+ # Retrieve all messages from the queue
305+ while not queue .empty ():
306+ try :
307+ message = queue .get_nowait ()
308+ messages .append (message )
309+ except Exception as e :
310+ # Handle unexpected errors (unlikely with get_nowait)
311+ break
310312 self .write ({"status" : "messages retrieved" , "messages" : messages })
311313 except Exception as e :
312314 LOG .error (f"Retrieving messages failed: { e } " )
@@ -324,13 +326,22 @@ async def get(self):
324326 return
325327
326328 # refuse if connect wasnt called first
327- if key not in CLIENTS :
329+ if key not in HiveMindHttpHandler . clients :
328330 self .write ({"error" : "Client is not connected" })
329331 return
330332
331- # send non-binary payloads to the client
332- messages = UNDELIVERED_BIN [key ]
333- UNDELIVERED_BIN [key ] = []
333+ messages = []
334+ queue = HiveMindHttpHandler .undelivered_bin [key ]
335+
336+ # Retrieve all messages from the queue
337+ while not queue .empty ():
338+ try :
339+ message = queue .get_nowait ()
340+ messages .append (message )
341+ except Exception as e :
342+ # Handle unexpected errors (unlikely with get_nowait)
343+ break
344+
334345 self .write ({"status" : "messages retrieved" , "b64_messages" : messages })
335346 except Exception as e :
336347 LOG .error (f"Retrieving messages failed: { e } " )
0 commit comments