Skip to content

Commit 9b7798a

Browse files
authored
Refactorings (#124)
* autopep8.sh + minor refactorings * add pydantic to `dev` & `streaming` requirements * minor refactorings * refactor hardcoded strings to `param.py` * add docstring to `streaming.param.py` * increase retry to 10 times * add `attribute_type` endpoint to handle just field attribute requests * handle case of field attribute request (non-callable attributes retrieval request) in PyMiloClient * handle case of field attribute request (non-callable attributes retrieval request) in PyMiloServer * add field attribute retrieval to check to scenario2 * `autopep8.sh` applied
1 parent 7cd8243 commit 9b7798a

File tree

8 files changed

+125
-49
lines changed

8 files changed

+125
-49
lines changed

dev-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ scipy>=0.19.1
44
uvicorn==0.30.6
55
fastapi==0.112.1
66
requests==2.32.3
7+
pydantic>=1.5.0
78
setuptools>=40.8.0
89
vulture>=1.0
910
bandit>=1.5.1

pymilo/streaming/communicator.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
"""PyMilo RESTFull Communication Mediums."""
33
import uvicorn
44
import requests
5-
from fastapi import FastAPI, Request
65
from pydantic import BaseModel
6+
from fastapi import FastAPI, Request
77
from .interfaces import ClientCommunicator
88

99

@@ -21,7 +21,7 @@ def __init__(self, server_url):
2121
self._server_url = server_url
2222
self.session = requests.Session()
2323
retries = requests.adapters.Retry(
24-
total=5,
24+
total=10,
2525
backoff_factor=0.1,
2626
status_forcelist=[500, 502, 503, 504]
2727
)
@@ -34,30 +34,45 @@ def download(self, payload):
3434
3535
:param payload: download request payload
3636
:type payload: dict
37-
:return: response of pymilo server
37+
:return: string serialized model
3838
"""
39-
return self.session.get(url=self._server_url + "/download/", json=payload, timeout=5)
39+
response = self.session.get(url=self._server_url + "/download/", json=payload, timeout=5)
40+
if response.status_code != 200:
41+
return None
42+
return response.json()["payload"]
4043

4144
def upload(self, payload):
4245
"""
4346
Upload the local ML model to the remote server.
4447
4548
:param payload: upload request payload
4649
:type payload: dict
47-
:return: response of pymilo server
50+
:return: True if upload was successful, False otherwise
4851
"""
49-
return self.session.post(url=self._server_url + "/upload/", json=payload, timeout=5)
52+
response = self.session.post(url=self._server_url + "/upload/", json=payload, timeout=5)
53+
return response.status_code == 200
5054

5155
def attribute_call(self, payload):
5256
"""
5357
Delegate the requested attribute call to the remote server.
5458
5559
:param payload: attribute call request payload
5660
:type payload: dict
57-
:return: response of pymilo server
61+
:return: json-encoded response of pymilo server
5862
"""
59-
return self.session.post(url=self._server_url + "/attribute_call/", json=payload, timeout=5)
63+
response = self.session.post(url=self._server_url + "/attribute_call/", json=payload, timeout=5)
64+
return response.json()
6065

66+
def attribute_type(self, payload):
67+
"""
68+
Identify the attribute type of the requested attribute.
69+
70+
:param payload: attribute type request payload
71+
:type payload: dict
72+
:return: response of pymilo server
73+
"""
74+
response = self.session.post(url=self._server_url + "/attribute_type/", json=payload, timeout=5)
75+
return response.json()
6176

6277

6378
class RESTServerCommunicator():
@@ -68,7 +83,7 @@ def __init__(
6883
ps,
6984
host: str = "127.0.0.1",
7085
port: int = 8000,
71-
):
86+
):
7287
"""
7388
Initialize the Pymilo RESTServerCommunicator instance.
7489
@@ -91,15 +106,21 @@ def setup_routes(self):
91106
class StandardPayload(BaseModel):
92107
client_id: str
93108
model_id: str
109+
94110
class DownloadPayload(StandardPayload):
95111
pass
112+
96113
class UploadPayload(StandardPayload):
97114
model: str
98-
class AttributePayload(StandardPayload):
115+
116+
class AttributeCallPayload(StandardPayload):
99117
attribute: str
100118
args: list
101119
kwargs: dict
102120

121+
class AttributeTypePayload(StandardPayload):
122+
attribute: str
123+
103124
@self.app.get("/download/")
104125
async def download(request: Request):
105126
body = await request.json()
@@ -126,14 +147,29 @@ async def upload(request: Request):
126147
async def attribute_call(request: Request):
127148
body = await request.json()
128149
body = self.parse(body)
129-
payload = AttributePayload(**body)
130-
message = "/attribute_call request from client: {} for model: {}".format(payload.client_id, payload.model_id)
150+
payload = AttributeCallPayload(**body)
151+
message = "/attribute_call request from client: {} for model: {}".format(
152+
payload.client_id, payload.model_id)
131153
result = self._ps.execute_model(payload)
132154
return {
133155
"message": message,
134156
"payload": result if result is not None else "The ML model has been updated in place."
135157
}
136158

159+
@self.app.post("/attribute_type/")
160+
async def attribute_type(request: Request):
161+
body = await request.json()
162+
body = self.parse(body)
163+
payload = AttributeTypePayload(**body)
164+
message = "/attribute_type request from client: {} for model: {}".format(
165+
payload.client_id, payload.model_id)
166+
is_callable, field_value = self._ps.is_callable_attribute(payload)
167+
return {
168+
"message": message,
169+
"attribute type": "method" if is_callable else "field",
170+
"attribute value": "" if is_callable else field_value,
171+
}
172+
137173
def parse(self, body):
138174
"""
139175
Parse the compressed encrypted body of the request.

pymilo/streaming/param.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# -*- coding: utf-8 -*-
2+
"""Streaming Parameters and constants."""
3+
PYMILO_CLIENT_INVALID_MODE = "Invalid mode, the given mode should be either `LOCAL`[default] or `DELEGATE`."
4+
PYMILO_CLIENT_MODEL_SYNCHED = "PyMiloClient synched the local ML model with the remote one successfully."
5+
PYMILO_CLIENT_LOCAL_MODEL_UPLOADED = "PyMiloClient uploaded the local model successfully."
6+
PYMILO_CLIENT_LOCAL_MODEL_UPLOAD_FAILED = "PyMiloClient failed to upload the local model."
7+
PYMILO_CLIENT_INVALID_ATTRIBUTE = "This attribute doesn't exist in either PymiloClient or the inner ML model."
8+
PYMILO_CLIENT_FAILED_TO_DOWNLOAD_REMOTE_MODEL = "PyMiloClient failed to download the remote ML model."
9+
10+
PYMILO_SERVER_NON_EXISTENT_ATTRIBUTE = "The requested attribute doesn't exist in this model."

pymilo/streaming/pymilo_client.py

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from .encryptor import DummyEncryptor
55
from .compressor import DummyCompressor
66
from ..pymilo_obj import Export, Import
7+
from .param import PYMILO_CLIENT_INVALID_MODE, PYMILO_CLIENT_MODEL_SYNCHED, \
8+
PYMILO_CLIENT_LOCAL_MODEL_UPLOADED, PYMILO_CLIENT_LOCAL_MODEL_UPLOAD_FAILED, \
9+
PYMILO_CLIENT_INVALID_ATTRIBUTE, PYMILO_CLIENT_FAILED_TO_DOWNLOAD_REMOTE_MODEL
710
from .communicator import RESTClientCommunicator
811
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
912

@@ -22,20 +25,17 @@ def __init__(
2225
self,
2326
model=None,
2427
mode=Mode.LOCAL,
25-
server="http://127.0.0.1",
26-
port= 8000
27-
):
28+
server_url="http://127.0.0.1:8000",
29+
):
2830
"""
2931
Initialize the Pymilo PymiloClient instance.
3032
3133
:param model: the ML model PyMiloClient wrapped around
3234
:type model: Any
3335
:param mode: the mode in which PymiloClient should work, either LOCAL mode or DELEGATE
3436
:type mode: str (LOCAL|DELEGATE)
35-
:param server: the url to which PyMilo Server listens
36-
:type server: str
37-
:param port: the port to which PyMilo Server listens
38-
:type port: int
37+
:param server_url: the url to which PyMilo Server listens
38+
:type server_url: str
3939
:return: an instance of the Pymilo PymiloClient class
4040
"""
4141
self._client_id = "0x_client_id"
@@ -44,10 +44,7 @@ def __init__(
4444
self._mode = mode
4545
self._compressor = DummyCompressor()
4646
self._encryptor = DummyEncryptor()
47-
self._communicator = RESTClientCommunicator(
48-
server_url="{}:{}".format(server, port)
49-
)
50-
47+
self._communicator = RESTClientCommunicator(server_url)
5148

5249
def toggle_mode(self, mode=Mode.LOCAL):
5350
"""
@@ -56,41 +53,41 @@ def toggle_mode(self, mode=Mode.LOCAL):
5653
:return: None
5754
"""
5855
if mode not in Mode.__members__.values():
59-
raise Exception("Invalid mode, the given mode should be either `LOCAL`[default] or `DELEGATE`.")
60-
self._mode = mode
56+
raise Exception(PYMILO_CLIENT_INVALID_MODE)
57+
if mode != self._mode:
58+
self._mode = mode
6159

6260
def download(self):
6361
"""
6462
Request for the remote ML model to download.
6563
6664
:return: None
6765
"""
68-
response = self._communicator.download({
69-
"client_id": self._client_id,
66+
serialized_model = self._communicator.download({
67+
"client_id": self._client_id,
7068
"model_id": self._model_id
7169
})
72-
if response.status_code != 200:
73-
print("Remote model download failed.")
74-
print("Remote model downloaded successfully.")
75-
serialized_model = response.json()["payload"]
70+
if serialized_model is None:
71+
print(PYMILO_CLIENT_FAILED_TO_DOWNLOAD_REMOTE_MODEL)
72+
return
7673
self._model = Import(file_adr=None, json_dump=serialized_model).to_model()
77-
print("Local model updated successfully.")
74+
print(PYMILO_CLIENT_MODEL_SYNCHED)
7875

7976
def upload(self):
8077
"""
8178
Upload the local ML model to the remote server.
8279
8380
:return: None
8481
"""
85-
response = self._communicator.upload({
86-
"client_id": self._client_id,
82+
succeed = self._communicator.upload({
83+
"client_id": self._client_id,
8784
"model_id": self._model_id,
8885
"model": Export(self._model).to_json(),
8986
})
90-
if response.status_code == 200:
91-
print("Local model uploaded successfully.")
87+
if succeed:
88+
print(PYMILO_CLIENT_LOCAL_MODEL_UPLOADED)
9289
else:
93-
print("Local model upload failed.")
90+
print(PYMILO_CLIENT_LOCAL_MODEL_UPLOAD_FAILED)
9491

9592
def __getattr__(self, attribute):
9693
"""
@@ -105,18 +102,31 @@ def __getattr__(self, attribute):
105102
if attribute in dir(self._model):
106103
return getattr(self._model, attribute)
107104
else:
108-
raise AttributeError("This attribute doesn't exist in either PymiloClient or the inner ML model.")
105+
raise AttributeError(PYMILO_CLIENT_INVALID_ATTRIBUTE)
109106
elif self._mode == Mode.DELEGATE:
110107
gdst = GeneralDataStructureTransporter()
108+
response = self._communicator.attribute_type(
109+
self._encryptor.encrypt(
110+
self._compressor.compress(
111+
{
112+
"client_id": self._client_id,
113+
"model_id": self._model_id,
114+
"attribute": attribute,
115+
}
116+
)
117+
)
118+
)
119+
if response["attribute type"] == "field":
120+
return gdst.deserialize(response, "attribute value", None)
121+
111122
def relayer(*args, **kwargs):
112-
print(f"Method '{attribute}' called with args: {args} and kwargs: {kwargs}")
113123
payload = {
114124
"client_id": self._client_id,
115125
"model_id": self._model_id,
116126
'attribute': attribute,
117127
'args': args,
118128
'kwargs': kwargs,
119-
}
129+
}
120130
payload["args"] = gdst.serialize(payload, "args", None)
121131
payload["kwargs"] = gdst.serialize(payload, "kwargs", None)
122132
result = self._communicator.attribute_call(
@@ -125,7 +135,6 @@ def relayer(*args, **kwargs):
125135
payload
126136
)
127137
)
128-
).json()
138+
)
129139
return gdst.deserialize(result, "payload", None)
130-
relayer.__doc__ = getattr(self._model.__class__, attribute).__doc__
131140
return relayer

pymilo/streaming/pymilo_server.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# -*- coding: utf-8 -*-
22
"""PyMiloServer for RESTFull protocol."""
33
from ..pymilo_obj import Export, Import
4-
from .compressor import DummyCompressor
54
from .encryptor import DummyEncryptor
5+
from .compressor import DummyCompressor
66
from .communicator import RESTServerCommunicator
7+
from .param import PYMILO_SERVER_NON_EXISTENT_ATTRIBUTE
78
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
89

910

@@ -13,7 +14,7 @@ class PymiloServer:
1314
def __init__(self, port=8000):
1415
"""
1516
Initialize the Pymilo PymiloServer instance.
16-
17+
1718
:param port: the port to which PyMiloServer listens
1819
:type port: int
1920
:return: an instance of the PymiloServer class
@@ -53,7 +54,7 @@ def execute_model(self, request):
5354
attribute = request.attribute
5455
retrieved_attribute = getattr(self._model, attribute, None)
5556
if retrieved_attribute is None:
56-
raise Exception("The requested attribute doesn't exist in this model.")
57+
raise Exception(PYMILO_SERVER_NON_EXISTENT_ATTRIBUTE)
5758
arguments = {
5859
'args': request.args,
5960
'kwargs': request.kwargs
@@ -65,3 +66,18 @@ def execute_model(self, request):
6566
self._model = output
6667
return None
6768
return gdst.serialize({'output': output}, 'output', None)
69+
70+
def is_callable_attribute(self, request):
71+
"""
72+
Check whether the requested attribute is callable or not.
73+
74+
:param request: request obj containing requested attribute to check it's type
75+
:type request: obj
76+
:return: True if it is callable False otherwise
77+
"""
78+
attribute = request.attribute
79+
retrieved_attribute = getattr(self._model, attribute, None)
80+
if callable(retrieved_attribute):
81+
return True, None
82+
else:
83+
return False, GeneralDataStructureTransporter().serialize({'output': retrieved_attribute}, 'output', None)

pymilo/transporters/randomstate_transporter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .transporter import AbstractTransporter
55
from ..utils.util import check_str_in_iterable
66

7+
78
class RandomStateTransporter(AbstractTransporter):
89
"""Customized PyMilo Transporter developed to handle RandomState field."""
910

@@ -31,8 +32,8 @@ def serialize(self, data, key, model_type):
3132
inner_random_state.get_state()[2],
3233
inner_random_state.get_state()[3],
3334
inner_random_state.get_state()[4],
34-
),
35-
}
35+
),
36+
}
3637
return data[key]
3738

3839
def deserialize(self, data, key, model_type):

streaming-requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
uvicorn>=0.14.0
22
fastapi>=0.68.0
3-
requests>=2.0.0
3+
requests>=2.0.0
4+
pydantic>=1.5.0

tests/test_ml_streaming/scenarios/scenario2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def scenario2():
2525
# 3.
2626
client.toggle_mode(Mode.DELEGATE)
2727
client.fit(x_train, y_train)
28+
remote_field = client.coef_
2829

2930
# 4.
3031
result = client.predict(x_test)
@@ -35,6 +36,7 @@ def scenario2():
3536

3637
# 6.
3738
client.toggle_mode(mode=Mode.LOCAL)
39+
local_field = client.coef_
3840
result = client.predict(x_test)
3941
mse_local = mean_squared_error(y_test, result)
40-
return np.abs(mse_server-mse_local)
42+
return np.abs(mse_server-mse_local) + np.abs(np.sum(local_field-remote_field))

0 commit comments

Comments
 (0)