Skip to content

Commit edcfba0

Browse files
authored
Mode Enum (#119)
* add `Mode` Enum to `pymilo_client.py` * update testcases accordingly * updaet according to Enum refactorings * add port to PyMiloServer init params * Update pymilo_client.py
1 parent 55d93f0 commit edcfba0

File tree

4 files changed

+23
-14
lines changed

4 files changed

+23
-14
lines changed

pymilo/streaming/pymilo_client.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
1+
from enum import Enum
12
from .encryptor import DummyEncryptor
23
from .compressor import DummyCompressor
34
from ..pymilo_obj import Export, Import
45
from .communicator import RESTClientCommunicator
56
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
7+
8+
9+
class Mode(Enum):
10+
"""fallback state of the PyMiloClient."""
11+
12+
LOCAL = 1
13+
DELEGATE = 2
14+
15+
616
class PymiloClient:
717

818
def __init__(
919
self,
1020
model=None,
11-
mode="LOCAL",
21+
mode=Mode.LOCAL,
1222
server="http://127.0.0.1",
1323
port= 8000
1424
):
@@ -22,9 +32,8 @@ def __init__(
2232
server_url="{}:{}".format(server, port)
2333
)
2434

25-
def toggle_mode(self, mode="LOCAL"):
26-
mode = mode.upper()
27-
if mode not in ["LOCAL", "DELEGATE"]:
35+
def toggle_mode(self, mode=Mode.LOCAL):
36+
if mode not in Mode.__members__.values():
2837
raise Exception("Invalid mode, the given mode should be either `LOCAL`[default] or `DELEGATE`.")
2938
self._mode = mode
3039

@@ -52,12 +61,12 @@ def upload(self):
5261
print("Local model upload failed.")
5362

5463
def __getattr__(self, attribute):
55-
if self._mode == "LOCAL":
64+
if self._mode == Mode.LOCAL:
5665
if attribute in dir(self._model):
5766
return getattr(self._model, attribute)
5867
else:
5968
raise AttributeError("This attribute doesn't exist either in PymiloClient or the inner ML model.")
60-
elif self._mode == "DELEGATE":
69+
elif self._mode == Mode.DELEGATE:
6170
gdst = GeneralDataStructureTransporter()
6271
def relayer(*args, **kwargs):
6372
print(f"Method '{attribute}' called with args: {args} and kwargs: {kwargs}")

pymilo/streaming/pymilo_server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77

88
class PymiloServer:
99

10-
def __init__(self):
10+
def __init__(self, port=8000):
1111
self._model = None
1212
self._compressor = DummyCompressor()
1313
self._encryptor = DummyEncryptor()
14-
self._communicator = RESTServerCommunicator(ps=self)
14+
self._communicator = RESTServerCommunicator(ps=self, port=port)
1515
self._communicator.run()
1616

1717
def export_model(self):

tests/test_ml_streaming/scenarios/scenario1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
from sklearn.metrics import mean_squared_error
33
from sklearn.linear_model import LinearRegression
4-
from pymilo.streaming.pymilo_client import PymiloClient
4+
from pymilo.streaming.pymilo_client import PymiloClient, Mode
55
from pymilo.utils.data_exporter import prepare_simple_regression_datasets
66

77

@@ -20,7 +20,7 @@ def scenario1():
2020

2121
# 2.
2222
linear_regression.fit(x_train, y_train)
23-
client = PymiloClient(model=linear_regression, mode="LOCAL")
23+
client = PymiloClient(model=linear_regression, mode=Mode.LOCAL)
2424

2525
# 3.
2626
result = client.predict(x_test)

tests/test_ml_streaming/scenarios/scenario2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
from sklearn.metrics import mean_squared_error
33
from sklearn.linear_model import LinearRegression
4-
from pymilo.streaming.pymilo_client import PymiloClient
4+
from pymilo.streaming.pymilo_client import PymiloClient, Mode
55
from pymilo.utils.data_exporter import prepare_simple_regression_datasets
66

77

@@ -17,13 +17,13 @@ def scenario2():
1717
# 1.
1818
x_train, y_train, x_test, y_test = prepare_simple_regression_datasets()
1919
linear_regression = LinearRegression()
20-
client = PymiloClient(model=linear_regression, mode="LOCAL")
20+
client = PymiloClient(model=linear_regression, mode=Mode.LOCAL)
2121

2222
# 2.
2323
client.upload()
2424

2525
# 3.
26-
client.toggle_mode(mode="DELEGATE")
26+
client.toggle_mode(Mode.DELEGATE)
2727
client.fit(x_train, y_train)
2828

2929
# 4.
@@ -34,7 +34,7 @@ def scenario2():
3434
client.download()
3535

3636
# 6.
37-
client.toggle_mode(mode="LOCAL")
37+
client.toggle_mode(mode=Mode.LOCAL)
3838
result = client.predict(x_test)
3939
mse_local = mean_squared_error(y_test, result)
4040
return np.abs(mse_server-mse_local)

0 commit comments

Comments
 (0)