Skip to content

Commit 8604c64

Browse files
authored
Add/websocket (#143)
* add url validation check in RESTClientCommunicator constructor * minor refactoring in RESTServerCommunicator constructor * WebSocketClientCommunicator implemented * WebSocketServerCommunicator implemented * refactor imports and update * ClientCommunicator Enum defined * ServerCommunicator Enum defined * `invalid url` and `websocket not connect` error messages added * url validation utility functions added * scenario1 generalized to support various communication medium protocols * scenario2 generalized to support various communication medium protocols * scenario3 generalized to support various communication medium protocols * `run_server` generalized to support various communication medium protocols * fixtures updated to support various communication medium protocols + testcases updated accordingly * PyMiloClient updated to support Communication Medium Protocol selection as input * `PyMiloServer` updated to support Communication Medium Protocol selection as input + minor refactorings to be more generalized * `websockets` added to streaming requirements * `websockets` added to dev requirements * remove secondary event loop creation * lightweighting ml streaming testcases * add public module docstring * `autopep8.sh` applied * remove un-used imports + increase sleep time to lower connection refuse errors in pytest * `CHANGELOG.md` updated * create even loop if it doesn't exist * lowercasing the letters for the starting of docstring :param , :type and :return sections * `CHANGELOG.md` updated * `CHANGELOG.md` updated * fix `ClientCommunicator` sudden override by the `ClientCommunicationProtocol` enum * combine protocol enums * I added `CommunicationProtocol` to the `__init__` file of the `streaming` module in order to make it more accessbile * `CHANGELOG.md` updated * `CHANGELOG.md` updated * fixing versions * remove scipy, pydantic * convert websocket req to >= * refactor check for socket close and add support for different `websockets` versions * update on `requirements.txt` * update on `requirements.txt` * `CHANGELOG.md` updated * remove trailing whitespaces
1 parent bfcfdc4 commit 8604c64

File tree

14 files changed

+450
-40
lines changed

14 files changed

+450
-40
lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
66

77
## [Unreleased]
88
### Added
9+
- `is_socket_closed` function in `streaming.communicator.py`
10+
- `validate_http_url` function in `streaming.util.py`
11+
- `validate_websocket_url` function in `streaming.util.py`
12+
- `ML Streaming` WebSocket testcases
13+
- `CommunicationProtocol` Enum in `streaming.communicator.py`
14+
- `WebSocketClientCommunicator` class in `streaming.communicator.py`
15+
- `WebSocketServerCommunicator` class in `streaming.communicator.py`
916
- batch operation testcases
1017
- `batch_export` function in `pymilo/pymilo_obj.py`
1118
- `batch_import` function in `pymilo/pymilo_obj.py`
@@ -17,6 +24,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
1724
- PyMilo exception types added in `pymilo/exceptions/__init__.py`
1825
- PyMilo exception types added in `pymilo/__init__.py`
1926
### Changed
27+
- `communication_protocol` parameter added to `PyMiloClient` class
28+
- `communication_protocol` parameter added to `PyMiloServer` class
29+
- `ML Streaming` testcases updated to support protocol selection
2030
- `README.md` updated
2131
- Tests config modified
2232
- Cross decomposition params initialized in `pymilo_param`

dev-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ uvicorn==0.32.0
55
fastapi==0.115.5
66
requests==2.32.3
77
pydantic>=1.5.0
8+
websockets==10.4
89
setuptools>=40.8.0
910
vulture>=1.0
1011
bandit>=1.5.1

pymilo/streaming/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .pymilo_client import PymiloClient
44
from .pymilo_server import PymiloServer
55
from .compressor import Compression
6+
from .communicator import CommunicationProtocol

pymilo/streaming/communicator.py

Lines changed: 311 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
# -*- coding: utf-8 -*-
2-
"""PyMilo RESTFull Communication Mediums."""
2+
"""PyMilo Communication Mediums."""
33
import json
4+
import asyncio
45
import uvicorn
56
import requests
7+
import websockets
8+
from enum import Enum
69
from pydantic import BaseModel
7-
from fastapi import FastAPI, Request
10+
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
811
from .interfaces import ClientCommunicator
12+
from .param import PYMILO_INVALID_URL, PYMILO_CLIENT_WEBSOCKET_NOT_CONNECTED
13+
from .util import validate_websocket_url, validate_http_url
914

1015

1116
class RESTClientCommunicator(ClientCommunicator):
@@ -19,6 +24,9 @@ def __init__(self, server_url):
1924
:type server_url: str
2025
:return: an instance of the Pymilo RESTClientCommunicator class
2126
"""
27+
is_valid, server_url = validate_http_url(server_url)
28+
if not is_valid:
29+
raise Exception(PYMILO_INVALID_URL)
2230
self._server_url = server_url
2331
self.session = requests.Session()
2432
retries = requests.adapters.Retry(
@@ -96,10 +104,10 @@ def __init__(
96104
:type port: int
97105
:return: an instance of the Pymilo RESTServerCommunicator class
98106
"""
99-
self.app = FastAPI()
107+
self._ps = ps
100108
self.host = host
101109
self.port = port
102-
self._ps = ps
110+
self.app = FastAPI()
103111
self.setup_routes()
104112

105113
def setup_routes(self):
@@ -188,3 +196,302 @@ def parse(self, body):
188196
def run(self):
189197
"""Run internal fastapi server."""
190198
uvicorn.run(self.app, host=self.host, port=self.port)
199+
200+
201+
class WebSocketClientCommunicator(ClientCommunicator):
202+
"""Facilitate working with the communication medium from the client side for the WebSocket protocol."""
203+
204+
def __init__(
205+
self,
206+
server_url: str = "ws://127.0.0.1:8000"
207+
):
208+
"""
209+
Initialize the WebSocketClientCommunicator instance.
210+
211+
:param server_url: the WebSocket server URL to connect to.
212+
:type server_url: str
213+
:return: an instance of the Pymilo WebSocketClientCommunicator class
214+
"""
215+
is_valid, url = validate_websocket_url(server_url)
216+
if not is_valid:
217+
raise Exception(PYMILO_INVALID_URL)
218+
self.server_url = url
219+
self.websocket = None
220+
self.connection_established = asyncio.Event() # Event to signal connection status
221+
# check for even loop existance
222+
if asyncio._get_running_loop() is None:
223+
self.loop = asyncio.new_event_loop()
224+
asyncio.set_event_loop(self.loop)
225+
else:
226+
self.loop = asyncio.get_event_loop()
227+
self.loop.run_until_complete(self.connect())
228+
229+
def is_socket_closed(self):
230+
"""
231+
Check if the WebSocket connection is closed.
232+
233+
:return: `True` if the WebSocket connection is closed or uninitialized, `False` otherwise.
234+
"""
235+
if self.websocket is None:
236+
return True
237+
elif hasattr(self.websocket, "closed"): # For older versions
238+
return self.websocket.closed
239+
elif hasattr(self.websocket, "state"): # For newer versions
240+
return self.websocket.state is websockets.protocol.State.CLOSED
241+
242+
async def connect(self):
243+
"""Establish a WebSocket connection with the server."""
244+
if self.is_socket_closed():
245+
self.websocket = await websockets.connect(self.server_url)
246+
print("Connected to the WebSocket server.")
247+
self.connection_established.set()
248+
249+
async def disconnect(self):
250+
"""Close the WebSocket connection."""
251+
if self.websocket:
252+
await self.websocket.close()
253+
254+
async def send_message(self, action: str, payload: dict) -> dict:
255+
"""
256+
Send a message to the WebSocket server.
257+
258+
:param action: the type of action to perform (e.g., 'download', 'upload').
259+
:type action: str
260+
:param payload: the payload associated with the action.
261+
:type payload: dict
262+
:return: the server's response as a JSON object.
263+
"""
264+
await self.connection_established.wait()
265+
266+
if self.is_socket_closed():
267+
raise RuntimeError(PYMILO_CLIENT_WEBSOCKET_NOT_CONNECTED)
268+
269+
message = json.dumps({"action": action, "payload": payload})
270+
await self.websocket.send(message)
271+
response = await self.websocket.recv()
272+
return json.loads(response)
273+
274+
def download(self, payload: dict) -> dict:
275+
"""
276+
Request the remote ML model to download.
277+
278+
:param payload: the payload for the download request.
279+
:type payload: dict
280+
:return: the downloaded model data.
281+
"""
282+
response = self.loop.run_until_complete(
283+
self.send_message("download", payload)
284+
)
285+
return response.get("payload")
286+
287+
def upload(self, payload: dict) -> bool:
288+
"""
289+
Upload the local ML model to the remote server.
290+
291+
:param payload: the payload for the upload request.
292+
:type payload: dict
293+
:return: true if the upload request is acknowledged.
294+
"""
295+
response = self.loop.run_until_complete(
296+
self.send_message("upload", payload)
297+
)
298+
return response.get("message") == "Upload request received."
299+
300+
def attribute_call(self, payload: dict) -> dict:
301+
"""
302+
Delegate the requested attribute call to the remote server.
303+
304+
:param payload: the payload containing attribute call details.
305+
:type payload: dict
306+
:return: the server's response to the attribute call.
307+
"""
308+
response = self.loop.run_until_complete(
309+
self.send_message("attribute_call", payload)
310+
)
311+
return response
312+
313+
def attribute_type(self, payload: dict) -> dict:
314+
"""
315+
Identify the attribute type of the requested attribute.
316+
317+
:param payload: the payload containing attribute type request.
318+
:type payload: dict
319+
:return: the server's response with the attribute type.
320+
"""
321+
response = self.loop.run_until_complete(
322+
self.send_message("attribute_type", payload)
323+
)
324+
return response
325+
326+
327+
class WebSocketServerCommunicator:
328+
"""Facilitate working with the communication medium from the server side for the WebSocket protocol."""
329+
330+
def __init__(
331+
self,
332+
ps,
333+
host: str = "127.0.0.1",
334+
port: int = 8000,
335+
):
336+
"""
337+
Initialize the WebSocketServerCommunicator instance.
338+
339+
:param ps: reference to the PyMilo server.
340+
:type ps: pymilo.streaming.PymiloServer
341+
:param host: the WebSocket server host address.
342+
:type host: str
343+
:param port: the WebSocket server port.
344+
:type port: int
345+
:return: an instance of the WebSocketServerCommunicator class.
346+
"""
347+
self._ps = ps
348+
self.host = host
349+
self.port = port
350+
self.app = FastAPI()
351+
self.active_connections: list[WebSocket] = []
352+
self.setup_routes()
353+
354+
def setup_routes(self):
355+
"""Configure the WebSocket endpoint to handle client connections."""
356+
@self.app.websocket("/")
357+
async def websocket_endpoint(websocket: WebSocket):
358+
await self.connect(websocket)
359+
try:
360+
while True:
361+
message = await websocket.receive_text()
362+
await self.handle_message(websocket, message)
363+
except WebSocketDisconnect:
364+
self.disconnect(websocket)
365+
366+
async def connect(self, websocket: WebSocket):
367+
"""
368+
Accept a WebSocket connection and store it.
369+
370+
:param websocket: the WebSocket connection to accept.
371+
:type websocket: webSocket
372+
"""
373+
await websocket.accept()
374+
self.active_connections.append(websocket)
375+
376+
def disconnect(self, websocket: WebSocket):
377+
"""
378+
Handle WebSocket disconnection.
379+
380+
:param websocket: the WebSocket connection to remove.
381+
:type websocket: webSocket
382+
"""
383+
self.active_connections.remove(websocket)
384+
385+
async def handle_message(self, websocket: WebSocket, message: str):
386+
"""
387+
Handle messages received from WebSocket clients.
388+
389+
:param websocket: the WebSocket connection from which the message was received.
390+
:type websocket: webSocket
391+
:param message: the message received from the client.
392+
:type message: str
393+
"""
394+
try:
395+
message = json.loads(message)
396+
action = message['action']
397+
print(f"Server received action: {action}")
398+
payload = self.parse(message['payload'])
399+
400+
if action == "download":
401+
response = self._handle_download()
402+
elif action == "upload":
403+
response = self._handle_upload(payload)
404+
elif action == "attribute_call":
405+
response = self._handle_attribute_call(payload)
406+
elif action == "attribute_type":
407+
response = self._handle_attribute_type(payload)
408+
else:
409+
response = {"error": f"Unknown action: {action}"}
410+
411+
await websocket.send_text(json.dumps(response))
412+
except Exception as e:
413+
await websocket.send_text(json.dumps({"error": str(e)}))
414+
415+
def _handle_download(self) -> dict:
416+
"""
417+
Handle download requests.
418+
419+
:return: a response containing the exported model.
420+
"""
421+
return {
422+
"message": "Download request received.",
423+
"payload": self._ps.export_model(),
424+
}
425+
426+
def _handle_upload(self, payload: dict) -> dict:
427+
"""
428+
Handle upload requests.
429+
430+
:param payload: the payload containing the model data to upload.
431+
:type payload: dict
432+
:return: a response indicating that the upload was processed.
433+
"""
434+
return {
435+
"message": "Upload request received.",
436+
"payload": self._ps.update_model(payload["model"]),
437+
}
438+
439+
def _handle_attribute_call(self, payload: dict) -> dict:
440+
"""
441+
Handle attribute call requests.
442+
443+
:param payload: the payload containing the attribute call details.
444+
:type payload: dict
445+
:return: a response with the result of the attribute call.
446+
"""
447+
result = self._ps.execute_model(payload)
448+
return {
449+
"message": "Attribute call executed.",
450+
"payload": result if result else "The ML model has been updated in place.",
451+
}
452+
453+
def _handle_attribute_type(self, payload: dict) -> dict:
454+
"""
455+
Handle attribute type queries.
456+
457+
:param payload: the payload containing the attribute to query.
458+
:type payload: dict
459+
:return: a response with the attribute type and value.
460+
"""
461+
is_callable, field_value = self._ps.is_callable_attribute(payload)
462+
return {
463+
"message": "Attribute type query executed.",
464+
"attribute type": "method" if is_callable else "field",
465+
"attribute value": "" if is_callable else field_value,
466+
}
467+
468+
def parse(self, message: str) -> dict:
469+
"""
470+
Parse the encrypted and compressed message.
471+
472+
:param message: the encrypted and compressed message to parse.
473+
:type message: str
474+
:return: the decrypted and extracted version of the message.
475+
"""
476+
return json.loads(
477+
self._ps._compressor.extract(
478+
self._ps._encryptor.decrypt(message)
479+
)
480+
)
481+
482+
def run(self):
483+
"""Run the internal FastAPI server."""
484+
uvicorn.run(self.app, host=self.host, port=self.port)
485+
486+
487+
class CommunicationProtocol(Enum):
488+
"""Communication protocol."""
489+
490+
REST = {
491+
"CLIENT": RESTClientCommunicator,
492+
"SERVER": RESTServerCommunicator,
493+
}
494+
WEBSOCKET = {
495+
"CLIENT": WebSocketClientCommunicator,
496+
"SERVER": WebSocketServerCommunicator,
497+
}

pymilo/streaming/param.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@
88
PYMILO_CLIENT_FAILED_TO_DOWNLOAD_REMOTE_MODEL = "PyMiloClient failed to download the remote ML model."
99

1010
PYMILO_SERVER_NON_EXISTENT_ATTRIBUTE = "The requested attribute doesn't exist in this model."
11+
PYMILO_INVALID_URL = "The given URL is not valid."
12+
PYMILO_CLIENT_WEBSOCKET_NOT_CONNECTED = "WebSocket is not connected."

0 commit comments

Comments
 (0)