Skip to content

Commit 9aef619

Browse files
authored
Add/compression methods (#125)
* refactorings + update `parse` to return python obj * add support of `gzip`, `zlib`, `lzma` and `bz2` compressing methods * add compressor to PyMiloClient init params * add docstring to compile method * rename `compile` to `encrypt_compress` * add compressor to PyMiloServer init params * `autopep8.sh` applied * update `run_server` script to set compression method * update scenarios to set compression method in PyMiloClient * raise concrete Exception type * combine enc + comp from client side * major refactorings + run ml streaming tests per compression method * `autopep8.sh` applied
1 parent 9b7798a commit 9aef619

File tree

8 files changed

+181
-43
lines changed

8 files changed

+181
-43
lines changed

pymilo/streaming/communicator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22
"""PyMilo RESTFull Communication Mediums."""
3+
import json
34
import uvicorn
45
import requests
56
from pydantic import BaseModel
@@ -178,9 +179,9 @@ def parse(self, body):
178179
:type body: str
179180
:return: the extracted decrypted version
180181
"""
181-
return self._ps._compressor.extract(
182-
self._ps._encryptor.decrypt(
183-
body
182+
return json.loads(
183+
self._ps._compressor.extract(
184+
self._ps._encryptor.decrypt(body)
184185
)
185186
)
186187

pymilo/streaming/compressor.py

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
# -*- coding: utf-8 -*-
22
"""Implementations of Compressor interface."""
3-
from .interfaces import Compressor
3+
import gzip
4+
import zlib
5+
import lzma
6+
import bz2
7+
import json
8+
import base64
9+
from enum import Enum
10+
from pymilo.streaming.interfaces import Compressor
411

512

613
class DummyCompressor(Compressor):
@@ -9,9 +16,99 @@ class DummyCompressor(Compressor):
916
@staticmethod
1017
def compress(payload):
1118
"""Compress the given payload in a dummy way, simply just return it (no compression applied)."""
12-
return payload
19+
return payload if isinstance(payload, str) else json.dumps(payload)
1320

1421
@staticmethod
1522
def extract(payload):
1623
"""Extract the given payload in a dummy way, simply just return it (no Extraction applied)."""
1724
return payload
25+
26+
27+
class GZIPCompressor(Compressor):
28+
"""GZIP implementation of the Compressor interface."""
29+
30+
@staticmethod
31+
def compress(payload):
32+
"""Compress the given payload using gzip."""
33+
if isinstance(payload, str):
34+
data = payload.encode('utf-8')
35+
else:
36+
data = json.dumps(payload).encode('utf-8')
37+
compressed_data = gzip.compress(data)
38+
return base64.b64encode(compressed_data).decode('utf-8')
39+
40+
@staticmethod
41+
def extract(payload):
42+
"""Extract the given payload using gzip."""
43+
data = base64.b64decode(payload)
44+
return gzip.decompress(data).decode('utf-8')
45+
46+
47+
class ZLIBCompressor(Compressor):
48+
"""ZLIB implementation of the Compressor interface."""
49+
50+
@staticmethod
51+
def compress(payload):
52+
"""Compress the given payload using zlib."""
53+
if isinstance(payload, str):
54+
data = payload.encode('utf-8')
55+
else:
56+
data = json.dumps(payload).encode('utf-8')
57+
compressed_data = zlib.compress(data)
58+
return base64.b64encode(compressed_data).decode('utf-8')
59+
60+
@staticmethod
61+
def extract(payload):
62+
"""Extract the given payload using zlib."""
63+
data = base64.b64decode(payload)
64+
return zlib.decompress(data).decode('utf-8')
65+
66+
67+
class LZMACompressor(Compressor):
68+
"""LZMA implementation of the Compressor interface."""
69+
70+
@staticmethod
71+
def compress(payload):
72+
"""Compress the given payload using lzma."""
73+
if isinstance(payload, str):
74+
data = payload.encode('utf-8')
75+
else:
76+
data = json.dumps(payload).encode('utf-8')
77+
compressed_data = lzma.compress(data)
78+
return base64.b64encode(compressed_data).decode('utf-8')
79+
80+
@staticmethod
81+
def extract(payload):
82+
"""Extract the given payload using lzma."""
83+
data = base64.b64decode(payload)
84+
return lzma.decompress(data).decode('utf-8')
85+
86+
87+
class BZ2Compressor(Compressor):
88+
"""BZ2 implementation of the Compressor interface."""
89+
90+
@staticmethod
91+
def compress(payload):
92+
"""Compress the given payload using bz2."""
93+
if isinstance(payload, str):
94+
data = payload.encode('utf-8')
95+
else:
96+
data = json.dumps(payload).encode('utf-8')
97+
compressed_data = bz2.compress(data)
98+
return base64.b64encode(compressed_data).decode('utf-8')
99+
100+
@staticmethod
101+
def extract(payload):
102+
"""Extract the given payload using bz2."""
103+
data = base64.b64decode(payload)
104+
return bz2.decompress(data).decode('utf-8')
105+
106+
107+
class Compression(Enum):
108+
"""Compression method used in end to end communication."""
109+
110+
NULL = DummyCompressor
111+
GZIP = GZIPCompressor
112+
ZLIB = ZLIBCompressor
113+
LZMA = LZMACompressor
114+
BZ2 = BZ2Compressor

pymilo/streaming/pymilo_client.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""PyMiloClient for RESTFull Protocol."""
33
from enum import Enum
44
from .encryptor import DummyEncryptor
5-
from .compressor import DummyCompressor
5+
from .compressor import Compression
66
from ..pymilo_obj import Export, Import
77
from .param import PYMILO_CLIENT_INVALID_MODE, PYMILO_CLIENT_MODEL_SYNCHED, \
88
PYMILO_CLIENT_LOCAL_MODEL_UPLOADED, PYMILO_CLIENT_LOCAL_MODEL_UPLOAD_FAILED, \
@@ -25,6 +25,7 @@ def __init__(
2525
self,
2626
model=None,
2727
mode=Mode.LOCAL,
28+
compressor=Compression.NULL,
2829
server_url="http://127.0.0.1:8000",
2930
):
3031
"""
@@ -34,6 +35,8 @@ def __init__(
3435
:type model: Any
3536
:param mode: the mode in which PymiloClient should work, either LOCAL mode or DELEGATE
3637
:type mode: str (LOCAL|DELEGATE)
38+
:param compressor: the compression method to be used in client-server communications
39+
:type compressor: pymilo.streaming.compressor.Compression
3740
:param server_url: the url to which PyMilo Server listens
3841
:type server_url: str
3942
:return: an instance of the Pymilo PymiloClient class
@@ -42,10 +45,24 @@ def __init__(
4245
self._model_id = "0x_model_id"
4346
self._model = model
4447
self._mode = mode
45-
self._compressor = DummyCompressor()
48+
self._compressor = compressor.value
4649
self._encryptor = DummyEncryptor()
4750
self._communicator = RESTClientCommunicator(server_url)
4851

52+
def encrypt_compress(self, body):
53+
"""
54+
Compress and Encrypt body payload.
55+
56+
:param body: body payload of the request
57+
:type body: dict
58+
:return: the compressed and encrypted version of the body payload
59+
"""
60+
return self._encryptor.encrypt(
61+
self._compressor.compress(
62+
body
63+
)
64+
)
65+
4966
def toggle_mode(self, mode=Mode.LOCAL):
5067
"""
5168
Toggle the PyMiloClient mode, either from LOCAL to DELEGATE or vice versa.
@@ -63,10 +80,14 @@ def download(self):
6380
6481
:return: None
6582
"""
66-
serialized_model = self._communicator.download({
67-
"client_id": self._client_id,
68-
"model_id": self._model_id
69-
})
83+
serialized_model = self._communicator.download(
84+
self.encrypt_compress(
85+
{
86+
"client_id": self._client_id,
87+
"model_id": self._model_id,
88+
}
89+
)
90+
)
7091
if serialized_model is None:
7192
print(PYMILO_CLIENT_FAILED_TO_DOWNLOAD_REMOTE_MODEL)
7293
return
@@ -79,11 +100,15 @@ def upload(self):
79100
80101
:return: None
81102
"""
82-
succeed = self._communicator.upload({
83-
"client_id": self._client_id,
84-
"model_id": self._model_id,
85-
"model": Export(self._model).to_json(),
86-
})
103+
succeed = self._communicator.upload(
104+
self.encrypt_compress(
105+
{
106+
"client_id": self._client_id,
107+
"model_id": self._model_id,
108+
"model": Export(self._model).to_json(),
109+
}
110+
)
111+
)
87112
if succeed:
88113
print(PYMILO_CLIENT_LOCAL_MODEL_UPLOADED)
89114
else:
@@ -106,14 +131,12 @@ def __getattr__(self, attribute):
106131
elif self._mode == Mode.DELEGATE:
107132
gdst = GeneralDataStructureTransporter()
108133
response = self._communicator.attribute_type(
109-
self._encryptor.encrypt(
110-
self._compressor.compress(
111-
{
112-
"client_id": self._client_id,
113-
"model_id": self._model_id,
114-
"attribute": attribute,
115-
}
116-
)
134+
self.encrypt_compress(
135+
{
136+
"client_id": self._client_id,
137+
"model_id": self._model_id,
138+
"attribute": attribute,
139+
}
117140
)
118141
)
119142
if response["attribute type"] == "field":
@@ -130,10 +153,8 @@ def relayer(*args, **kwargs):
130153
payload["args"] = gdst.serialize(payload, "args", None)
131154
payload["kwargs"] = gdst.serialize(payload, "kwargs", None)
132155
result = self._communicator.attribute_call(
133-
self._encryptor.encrypt(
134-
self._compressor.compress(
135-
payload
136-
)
156+
self.encrypt_compress(
157+
payload
137158
)
138159
)
139160
return gdst.deserialize(result, "payload", None)

pymilo/streaming/pymilo_server.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# -*- coding: utf-8 -*-
22
"""PyMiloServer for RESTFull protocol."""
33
from ..pymilo_obj import Export, Import
4+
from .compressor import Compression
45
from .encryptor import DummyEncryptor
5-
from .compressor import DummyCompressor
66
from .communicator import RESTServerCommunicator
77
from .param import PYMILO_SERVER_NON_EXISTENT_ATTRIBUTE
88
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
@@ -11,16 +11,18 @@
1111
class PymiloServer:
1212
"""Facilitate streaming the ML models."""
1313

14-
def __init__(self, port=8000):
14+
def __init__(self, port=8000, compressor=Compression.NULL):
1515
"""
1616
Initialize the Pymilo PymiloServer instance.
1717
1818
:param port: the port to which PyMiloServer listens
1919
:type port: int
20+
:param compressor: the compression method to be used in client-server communications
21+
:type compressor: pymilo.streaming.compressor.Compression
2022
:return: an instance of the PymiloServer class
2123
"""
2224
self._model = None
23-
self._compressor = DummyCompressor()
25+
self._compressor = compressor.value
2426
self._encryptor = DummyEncryptor()
2527
self._communicator = RESTServerCommunicator(ps=self, port=port)
2628

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,15 @@
1+
import argparse
2+
from pymilo.streaming.compressor import Compression
13
from pymilo.streaming.pymilo_server import PymiloServer
24

3-
communicator = PymiloServer()._communicator
4-
communicator.run()
5+
6+
def main():
7+
parser = argparse.ArgumentParser(description='Run the Pymilo server with a specified compression method.')
8+
parser.add_argument('--compression', type=str, choices=['NULL', 'GZIP', 'ZLIB', 'LZMA', 'BZ2'], default='NULL',
9+
help='Specify the compression method (NULL, GZIP, ZLIB, LZMA, or BZ2). Default is NULL.')
10+
args = parser.parse_args()
11+
communicator = PymiloServer(compressor=Compression[args.compression])._communicator
12+
communicator.run()
13+
14+
if __name__ == '__main__':
15+
main()

tests/test_ml_streaming/scenarios/scenario1.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import numpy as np
22
from sklearn.metrics import mean_squared_error
33
from sklearn.linear_model import LinearRegression
4+
from pymilo.streaming.compressor import Compression
45
from pymilo.streaming.pymilo_client import PymiloClient, Mode
56
from pymilo.utils.data_exporter import prepare_simple_regression_datasets
67

78

8-
def scenario1():
9+
def scenario1(compression_method):
910
# 1. create model in local
1011
# 2. train model in local
1112
# 3. calculate mse before streaming
@@ -20,7 +21,7 @@ def scenario1():
2021

2122
# 2.
2223
linear_regression.fit(x_train, y_train)
23-
client = PymiloClient(model=linear_regression, mode=Mode.LOCAL)
24+
client = PymiloClient(model=linear_regression, mode=Mode.LOCAL, compressor=Compression[compression_method])
2425

2526
# 3.
2627
result = client.predict(x_test)

tests/test_ml_streaming/scenarios/scenario2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import numpy as np
22
from sklearn.metrics import mean_squared_error
33
from sklearn.linear_model import LinearRegression
4+
from pymilo.streaming.compressor import Compression
45
from pymilo.streaming.pymilo_client import PymiloClient, Mode
56
from pymilo.utils.data_exporter import prepare_simple_regression_datasets
67

78

8-
def scenario2():
9+
def scenario2(compression_method):
910
# 1. create model in local
1011
# 2. upload model to server
1112
# 3. train model in server
@@ -17,7 +18,7 @@ def scenario2():
1718
# 1.
1819
x_train, y_train, x_test, y_test = prepare_simple_regression_datasets()
1920
linear_regression = LinearRegression()
20-
client = PymiloClient(model=linear_regression, mode=Mode.LOCAL)
21+
client = PymiloClient(model=linear_regression, mode=Mode.LOCAL, compressor=Compression[compression_method])
2122

2223
# 2.
2324
client.upload()

tests/test_ml_streaming/test_streaming.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,30 @@
66
from scenarios.scenario1 import scenario1
77
from scenarios.scenario2 import scenario2
88

9-
@pytest.fixture(scope="session")
10-
def prepare_server():
9+
@pytest.fixture(scope="session", params=["NULL", "GZIP", "ZLIB", "LZMA", "BZ2"])
10+
def prepare_server(request):
11+
compression_method = request.param
1112
path = os.path.join(
1213
os.getcwd(),
1314
"tests",
1415
"test_ml_streaming",
15-
"run_server.py"
16+
"run_server.py",
1617
)
1718
server_proc = subprocess.Popen(
1819
[
1920
executable,
2021
path,
22+
"--compression", compression_method
2123
],
2224
)
2325
time.sleep(2)
24-
yield server_proc
26+
yield (server_proc, compression_method)
2527
server_proc.terminate()
2628

2729
def test1(prepare_server):
28-
assert scenario1() == 0
30+
_, compression_method = prepare_server
31+
assert scenario1(compression_method) == 0
2932

3033
def test2(prepare_server):
31-
assert scenario2() == 0
34+
_, compression_method = prepare_server
35+
assert scenario2(compression_method) == 0

0 commit comments

Comments
 (0)