Skip to content

Commit 8467776

Browse files
authored
Fix/minor naming feedback (#129)
* add ability to import `PymiloClient`, `PymiloServer` and `Compression` directly from `streaming` package * convert Mode to an inner class + update non-private field attributes naming convention * update non-private field attributes naming convention * refactor testcases based on applied updates
1 parent 641a10d commit 8467776

File tree

6 files changed

+39
-37
lines changed

6 files changed

+39
-37
lines changed

pymilo/streaming/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
# -*- coding: utf-8 -*-
22
"""PyMilo ML Streaming."""
3+
from .pymilo_client import PymiloClient
4+
from .pymilo_server import PymiloServer
5+
from .compressor import Compression

pymilo/streaming/pymilo_client.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,15 @@
1111
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
1212

1313

14-
class Mode(Enum):
15-
"""fallback state of the PyMiloClient."""
16-
17-
LOCAL = 1
18-
DELEGATE = 2
19-
20-
2114
class PymiloClient:
2215
"""Facilitate working with the PyMilo server."""
2316

17+
class Mode(Enum):
18+
"""fallback state of the PyMiloClient."""
19+
20+
LOCAL = 1
21+
DELEGATE = 2
22+
2423
def __init__(
2524
self,
2625
model=None,
@@ -41,9 +40,9 @@ def __init__(
4140
:type server_url: str
4241
:return: an instance of the Pymilo PymiloClient class
4342
"""
44-
self._client_id = "0x_client_id"
45-
self._model_id = "0x_model_id"
46-
self._model = model
43+
self.model = model
44+
self.client_id = "0x_client_id"
45+
self.model_id = "0x_model_id"
4746
self._mode = mode
4847
self._compressor = compressor.value
4948
self._encryptor = DummyEncryptor()
@@ -69,7 +68,7 @@ def toggle_mode(self, mode=Mode.LOCAL):
6968
7069
:return: None
7170
"""
72-
if mode not in Mode.__members__.values():
71+
if mode not in PymiloClient.Mode.__members__.values():
7372
raise Exception(PYMILO_CLIENT_INVALID_MODE)
7473
if mode != self._mode:
7574
self._mode = mode
@@ -83,15 +82,15 @@ def download(self):
8382
serialized_model = self._communicator.download(
8483
self.encrypt_compress(
8584
{
86-
"client_id": self._client_id,
87-
"model_id": self._model_id,
85+
"client_id": self.client_id,
86+
"model_id": self.model_id,
8887
}
8988
)
9089
)
9190
if serialized_model is None:
9291
print(PYMILO_CLIENT_FAILED_TO_DOWNLOAD_REMOTE_MODEL)
9392
return
94-
self._model = Import(file_adr=None, json_dump=serialized_model).to_model()
93+
self.model = Import(file_adr=None, json_dump=serialized_model).to_model()
9594
print(PYMILO_CLIENT_MODEL_SYNCHED)
9695

9796
def upload(self):
@@ -103,9 +102,9 @@ def upload(self):
103102
succeed = self._communicator.upload(
104103
self.encrypt_compress(
105104
{
106-
"client_id": self._client_id,
107-
"model_id": self._model_id,
108-
"model": Export(self._model).to_json(),
105+
"client_id": self.client_id,
106+
"model_id": self.model_id,
107+
"model": Export(self.model).to_json(),
109108
}
110109
)
111110
)
@@ -123,18 +122,18 @@ def __getattr__(self, attribute):
123122
124123
:return: Any
125124
"""
126-
if self._mode == Mode.LOCAL:
127-
if attribute in dir(self._model):
128-
return getattr(self._model, attribute)
125+
if self._mode == PymiloClient.Mode.LOCAL:
126+
if attribute in dir(self.model):
127+
return getattr(self.model, attribute)
129128
else:
130129
raise AttributeError(PYMILO_CLIENT_INVALID_ATTRIBUTE)
131-
elif self._mode == Mode.DELEGATE:
130+
elif self._mode == PymiloClient.Mode.DELEGATE:
132131
gdst = GeneralDataStructureTransporter()
133132
response = self._communicator.attribute_type(
134133
self.encrypt_compress(
135134
{
136-
"client_id": self._client_id,
137-
"model_id": self._model_id,
135+
"client_id": self.client_id,
136+
"model_id": self.model_id,
138137
"attribute": attribute,
139138
}
140139
)
@@ -144,8 +143,8 @@ def __getattr__(self, attribute):
144143

145144
def relayer(*args, **kwargs):
146145
payload = {
147-
"client_id": self._client_id,
148-
"model_id": self._model_id,
146+
"client_id": self.client_id,
147+
"model_id": self.model_id,
149148
'attribute': attribute,
150149
'args': args,
151150
'kwargs': kwargs,

pymilo/streaming/pymilo_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, port=8000, compressor=Compression.NULL):
2424
self._model = None
2525
self._compressor = compressor.value
2626
self._encryptor = DummyEncryptor()
27-
self._communicator = RESTServerCommunicator(ps=self, port=port)
27+
self.communicator = RESTServerCommunicator(ps=self, port=port)
2828

2929
def export_model(self):
3030
"""

tests/test_ml_streaming/run_server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import argparse
2-
from pymilo.streaming.compressor import Compression
3-
from pymilo.streaming.pymilo_server import PymiloServer
2+
from pymilo.streaming import Compression
3+
from pymilo.streaming import PymiloServer
44

55

66
def main():
77
parser = argparse.ArgumentParser(description='Run the Pymilo server with a specified compression method.')
88
parser.add_argument('--compression', type=str, choices=['NULL', 'GZIP', 'ZLIB', 'LZMA', 'BZ2'], default='NULL',
99
help='Specify the compression method (NULL, GZIP, ZLIB, LZMA, or BZ2). Default is NULL.')
1010
args = parser.parse_args()
11-
communicator = PymiloServer(compressor=Compression[args.compression])._communicator
11+
communicator = PymiloServer(compressor=Compression[args.compression]).communicator
1212
communicator.run()
1313

1414
if __name__ == '__main__':

tests/test_ml_streaming/scenarios/scenario1.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import numpy as np
2+
from pymilo.streaming import Compression
3+
from pymilo.streaming import PymiloClient
24
from sklearn.metrics import mean_squared_error
35
from sklearn.linear_model import LinearRegression
4-
from pymilo.streaming.compressor import Compression
5-
from pymilo.streaming.pymilo_client import PymiloClient, Mode
66
from pymilo.utils.data_exporter import prepare_simple_regression_datasets
77

88

@@ -21,7 +21,7 @@ def scenario1(compression_method):
2121

2222
# 2.
2323
linear_regression.fit(x_train, y_train)
24-
client = PymiloClient(model=linear_regression, mode=Mode.LOCAL, compressor=Compression[compression_method])
24+
client = PymiloClient(model=linear_regression, mode=PymiloClient.Mode.LOCAL, compressor=Compression[compression_method])
2525

2626
# 3.
2727
result = client.predict(x_test)

tests/test_ml_streaming/scenarios/scenario2.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import numpy as np
2+
from pymilo.streaming import Compression
3+
from pymilo.streaming import PymiloClient
24
from sklearn.metrics import mean_squared_error
35
from sklearn.linear_model import LinearRegression
4-
from pymilo.streaming.compressor import Compression
5-
from pymilo.streaming.pymilo_client import PymiloClient, Mode
66
from pymilo.utils.data_exporter import prepare_simple_regression_datasets
77

88

@@ -18,13 +18,13 @@ def scenario2(compression_method):
1818
# 1.
1919
x_train, y_train, x_test, y_test = prepare_simple_regression_datasets()
2020
linear_regression = LinearRegression()
21-
client = PymiloClient(model=linear_regression, mode=Mode.LOCAL, compressor=Compression[compression_method])
21+
client = PymiloClient(model=linear_regression, mode=PymiloClient.Mode.LOCAL, compressor=Compression[compression_method])
2222

2323
# 2.
2424
client.upload()
2525

2626
# 3.
27-
client.toggle_mode(Mode.DELEGATE)
27+
client.toggle_mode(PymiloClient.Mode.DELEGATE)
2828
client.fit(x_train, y_train)
2929
remote_field = client.coef_
3030

@@ -36,7 +36,7 @@ def scenario2(compression_method):
3636
client.download()
3737

3838
# 6.
39-
client.toggle_mode(mode=Mode.LOCAL)
39+
client.toggle_mode(mode=PymiloClient.Mode.LOCAL)
4040
local_field = client.coef_
4141
result = client.predict(x_test)
4242
mse_local = mean_squared_error(y_test, result)

0 commit comments

Comments
 (0)