1
1
# -*- coding: utf-8 -*-
2
- """PyMilo RESTFull Communication Mediums."""
2
+ """PyMilo Communication Mediums."""
3
3
import json
4
+ import asyncio
4
5
import uvicorn
5
6
import requests
7
+ import websockets
8
+ from enum import Enum
6
9
from pydantic import BaseModel
7
- from fastapi import FastAPI , Request
10
+ from fastapi import FastAPI , Request , WebSocket , WebSocketDisconnect
8
11
from .interfaces import ClientCommunicator
12
+ from .param import PYMILO_INVALID_URL , PYMILO_CLIENT_WEBSOCKET_NOT_CONNECTED
13
+ from .util import validate_websocket_url , validate_http_url
9
14
10
15
11
16
class RESTClientCommunicator (ClientCommunicator ):
@@ -19,6 +24,9 @@ def __init__(self, server_url):
19
24
:type server_url: str
20
25
:return: an instance of the Pymilo RESTClientCommunicator class
21
26
"""
27
+ is_valid , server_url = validate_http_url (server_url )
28
+ if not is_valid :
29
+ raise Exception (PYMILO_INVALID_URL )
22
30
self ._server_url = server_url
23
31
self .session = requests .Session ()
24
32
retries = requests .adapters .Retry (
@@ -96,10 +104,10 @@ def __init__(
96
104
:type port: int
97
105
:return: an instance of the Pymilo RESTServerCommunicator class
98
106
"""
99
- self .app = FastAPI ()
107
+ self ._ps = ps
100
108
self .host = host
101
109
self .port = port
102
- self ._ps = ps
110
+ self .app = FastAPI ()
103
111
self .setup_routes ()
104
112
105
113
def setup_routes (self ):
@@ -188,3 +196,302 @@ def parse(self, body):
188
196
def run (self ):
189
197
"""Run internal fastapi server."""
190
198
uvicorn .run (self .app , host = self .host , port = self .port )
199
+
200
+
201
+ class WebSocketClientCommunicator (ClientCommunicator ):
202
+ """Facilitate working with the communication medium from the client side for the WebSocket protocol."""
203
+
204
+ def __init__ (
205
+ self ,
206
+ server_url : str = "ws://127.0.0.1:8000"
207
+ ):
208
+ """
209
+ Initialize the WebSocketClientCommunicator instance.
210
+
211
+ :param server_url: the WebSocket server URL to connect to.
212
+ :type server_url: str
213
+ :return: an instance of the Pymilo WebSocketClientCommunicator class
214
+ """
215
+ is_valid , url = validate_websocket_url (server_url )
216
+ if not is_valid :
217
+ raise Exception (PYMILO_INVALID_URL )
218
+ self .server_url = url
219
+ self .websocket = None
220
+ self .connection_established = asyncio .Event () # Event to signal connection status
221
+ # check for even loop existance
222
+ if asyncio ._get_running_loop () is None :
223
+ self .loop = asyncio .new_event_loop ()
224
+ asyncio .set_event_loop (self .loop )
225
+ else :
226
+ self .loop = asyncio .get_event_loop ()
227
+ self .loop .run_until_complete (self .connect ())
228
+
229
+ def is_socket_closed (self ):
230
+ """
231
+ Check if the WebSocket connection is closed.
232
+
233
+ :return: `True` if the WebSocket connection is closed or uninitialized, `False` otherwise.
234
+ """
235
+ if self .websocket is None :
236
+ return True
237
+ elif hasattr (self .websocket , "closed" ): # For older versions
238
+ return self .websocket .closed
239
+ elif hasattr (self .websocket , "state" ): # For newer versions
240
+ return self .websocket .state is websockets .protocol .State .CLOSED
241
+
242
+ async def connect (self ):
243
+ """Establish a WebSocket connection with the server."""
244
+ if self .is_socket_closed ():
245
+ self .websocket = await websockets .connect (self .server_url )
246
+ print ("Connected to the WebSocket server." )
247
+ self .connection_established .set ()
248
+
249
+ async def disconnect (self ):
250
+ """Close the WebSocket connection."""
251
+ if self .websocket :
252
+ await self .websocket .close ()
253
+
254
+ async def send_message (self , action : str , payload : dict ) -> dict :
255
+ """
256
+ Send a message to the WebSocket server.
257
+
258
+ :param action: the type of action to perform (e.g., 'download', 'upload').
259
+ :type action: str
260
+ :param payload: the payload associated with the action.
261
+ :type payload: dict
262
+ :return: the server's response as a JSON object.
263
+ """
264
+ await self .connection_established .wait ()
265
+
266
+ if self .is_socket_closed ():
267
+ raise RuntimeError (PYMILO_CLIENT_WEBSOCKET_NOT_CONNECTED )
268
+
269
+ message = json .dumps ({"action" : action , "payload" : payload })
270
+ await self .websocket .send (message )
271
+ response = await self .websocket .recv ()
272
+ return json .loads (response )
273
+
274
+ def download (self , payload : dict ) -> dict :
275
+ """
276
+ Request the remote ML model to download.
277
+
278
+ :param payload: the payload for the download request.
279
+ :type payload: dict
280
+ :return: the downloaded model data.
281
+ """
282
+ response = self .loop .run_until_complete (
283
+ self .send_message ("download" , payload )
284
+ )
285
+ return response .get ("payload" )
286
+
287
+ def upload (self , payload : dict ) -> bool :
288
+ """
289
+ Upload the local ML model to the remote server.
290
+
291
+ :param payload: the payload for the upload request.
292
+ :type payload: dict
293
+ :return: true if the upload request is acknowledged.
294
+ """
295
+ response = self .loop .run_until_complete (
296
+ self .send_message ("upload" , payload )
297
+ )
298
+ return response .get ("message" ) == "Upload request received."
299
+
300
+ def attribute_call (self , payload : dict ) -> dict :
301
+ """
302
+ Delegate the requested attribute call to the remote server.
303
+
304
+ :param payload: the payload containing attribute call details.
305
+ :type payload: dict
306
+ :return: the server's response to the attribute call.
307
+ """
308
+ response = self .loop .run_until_complete (
309
+ self .send_message ("attribute_call" , payload )
310
+ )
311
+ return response
312
+
313
+ def attribute_type (self , payload : dict ) -> dict :
314
+ """
315
+ Identify the attribute type of the requested attribute.
316
+
317
+ :param payload: the payload containing attribute type request.
318
+ :type payload: dict
319
+ :return: the server's response with the attribute type.
320
+ """
321
+ response = self .loop .run_until_complete (
322
+ self .send_message ("attribute_type" , payload )
323
+ )
324
+ return response
325
+
326
+
327
+ class WebSocketServerCommunicator :
328
+ """Facilitate working with the communication medium from the server side for the WebSocket protocol."""
329
+
330
+ def __init__ (
331
+ self ,
332
+ ps ,
333
+ host : str = "127.0.0.1" ,
334
+ port : int = 8000 ,
335
+ ):
336
+ """
337
+ Initialize the WebSocketServerCommunicator instance.
338
+
339
+ :param ps: reference to the PyMilo server.
340
+ :type ps: pymilo.streaming.PymiloServer
341
+ :param host: the WebSocket server host address.
342
+ :type host: str
343
+ :param port: the WebSocket server port.
344
+ :type port: int
345
+ :return: an instance of the WebSocketServerCommunicator class.
346
+ """
347
+ self ._ps = ps
348
+ self .host = host
349
+ self .port = port
350
+ self .app = FastAPI ()
351
+ self .active_connections : list [WebSocket ] = []
352
+ self .setup_routes ()
353
+
354
+ def setup_routes (self ):
355
+ """Configure the WebSocket endpoint to handle client connections."""
356
+ @self .app .websocket ("/" )
357
+ async def websocket_endpoint (websocket : WebSocket ):
358
+ await self .connect (websocket )
359
+ try :
360
+ while True :
361
+ message = await websocket .receive_text ()
362
+ await self .handle_message (websocket , message )
363
+ except WebSocketDisconnect :
364
+ self .disconnect (websocket )
365
+
366
+ async def connect (self , websocket : WebSocket ):
367
+ """
368
+ Accept a WebSocket connection and store it.
369
+
370
+ :param websocket: the WebSocket connection to accept.
371
+ :type websocket: webSocket
372
+ """
373
+ await websocket .accept ()
374
+ self .active_connections .append (websocket )
375
+
376
+ def disconnect (self , websocket : WebSocket ):
377
+ """
378
+ Handle WebSocket disconnection.
379
+
380
+ :param websocket: the WebSocket connection to remove.
381
+ :type websocket: webSocket
382
+ """
383
+ self .active_connections .remove (websocket )
384
+
385
+ async def handle_message (self , websocket : WebSocket , message : str ):
386
+ """
387
+ Handle messages received from WebSocket clients.
388
+
389
+ :param websocket: the WebSocket connection from which the message was received.
390
+ :type websocket: webSocket
391
+ :param message: the message received from the client.
392
+ :type message: str
393
+ """
394
+ try :
395
+ message = json .loads (message )
396
+ action = message ['action' ]
397
+ print (f"Server received action: { action } " )
398
+ payload = self .parse (message ['payload' ])
399
+
400
+ if action == "download" :
401
+ response = self ._handle_download ()
402
+ elif action == "upload" :
403
+ response = self ._handle_upload (payload )
404
+ elif action == "attribute_call" :
405
+ response = self ._handle_attribute_call (payload )
406
+ elif action == "attribute_type" :
407
+ response = self ._handle_attribute_type (payload )
408
+ else :
409
+ response = {"error" : f"Unknown action: { action } " }
410
+
411
+ await websocket .send_text (json .dumps (response ))
412
+ except Exception as e :
413
+ await websocket .send_text (json .dumps ({"error" : str (e )}))
414
+
415
+ def _handle_download (self ) -> dict :
416
+ """
417
+ Handle download requests.
418
+
419
+ :return: a response containing the exported model.
420
+ """
421
+ return {
422
+ "message" : "Download request received." ,
423
+ "payload" : self ._ps .export_model (),
424
+ }
425
+
426
+ def _handle_upload (self , payload : dict ) -> dict :
427
+ """
428
+ Handle upload requests.
429
+
430
+ :param payload: the payload containing the model data to upload.
431
+ :type payload: dict
432
+ :return: a response indicating that the upload was processed.
433
+ """
434
+ return {
435
+ "message" : "Upload request received." ,
436
+ "payload" : self ._ps .update_model (payload ["model" ]),
437
+ }
438
+
439
+ def _handle_attribute_call (self , payload : dict ) -> dict :
440
+ """
441
+ Handle attribute call requests.
442
+
443
+ :param payload: the payload containing the attribute call details.
444
+ :type payload: dict
445
+ :return: a response with the result of the attribute call.
446
+ """
447
+ result = self ._ps .execute_model (payload )
448
+ return {
449
+ "message" : "Attribute call executed." ,
450
+ "payload" : result if result else "The ML model has been updated in place." ,
451
+ }
452
+
453
+ def _handle_attribute_type (self , payload : dict ) -> dict :
454
+ """
455
+ Handle attribute type queries.
456
+
457
+ :param payload: the payload containing the attribute to query.
458
+ :type payload: dict
459
+ :return: a response with the attribute type and value.
460
+ """
461
+ is_callable , field_value = self ._ps .is_callable_attribute (payload )
462
+ return {
463
+ "message" : "Attribute type query executed." ,
464
+ "attribute type" : "method" if is_callable else "field" ,
465
+ "attribute value" : "" if is_callable else field_value ,
466
+ }
467
+
468
+ def parse (self , message : str ) -> dict :
469
+ """
470
+ Parse the encrypted and compressed message.
471
+
472
+ :param message: the encrypted and compressed message to parse.
473
+ :type message: str
474
+ :return: the decrypted and extracted version of the message.
475
+ """
476
+ return json .loads (
477
+ self ._ps ._compressor .extract (
478
+ self ._ps ._encryptor .decrypt (message )
479
+ )
480
+ )
481
+
482
+ def run (self ):
483
+ """Run the internal FastAPI server."""
484
+ uvicorn .run (self .app , host = self .host , port = self .port )
485
+
486
+
487
+ class CommunicationProtocol (Enum ):
488
+ """Communication protocol."""
489
+
490
+ REST = {
491
+ "CLIENT" : RESTClientCommunicator ,
492
+ "SERVER" : RESTServerCommunicator ,
493
+ }
494
+ WEBSOCKET = {
495
+ "CLIENT" : WebSocketClientCommunicator ,
496
+ "SERVER" : WebSocketServerCommunicator ,
497
+ }
0 commit comments