Skip to content

Commit 07ce47a

Browse files
authored
Refactoring (#180)
* remove comment * refactor hardcoded string and put in `pymilo_param.py` * call `to_model` on pymilo `Import` object * replace `run_server.py` with pymilo CLI * refactor `test_streaming` test cases to use pymilo CLI * run `autopep8.sh` * `CHANGELOG.md` updated * update seconds for `time.sleep` * replace load from url to load from local file, set both timeouts to 10s * remove returning `server_proc` * put back the `run_server.py` script * update testcases to use `run_server.py` script instead of direct PyMilo CLI
1 parent e3902dd commit 07ce47a

File tree

8 files changed

+57
-18
lines changed

8 files changed

+57
-18
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
1414
- `OVERVIEW` in `pymilo_param.py`
1515
- `get_sklearn_class` in `utils.util.py`
1616
### Changed
17+
- `ML Streaming` testcases modified to use PyMilo CLI
1718
- `to_pymilo_issue` function in `PymiloException`
1819
- `valid_url_valid_file` testcase added in `test_exceptions.py`
1920
- `valid_url_valid_file` function in `import_exceptions.py`

pymilo/__main__.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import re
44
import argparse
55
from art import tprint
6-
from .pymilo_param import PYMILO_VERSION, URL_REGEX
6+
from .pymilo_param import (
7+
PYMILO_VERSION,
8+
URL_REGEX,
9+
CLI_MORE_INFO,
10+
CLI_UNKNOWN_MODEL,
11+
CLI_ML_STREAMING_NOT_INSTALLED,
12+
)
713
from .pymilo_func import print_supported_ml_models, pymilo_help
814
from .pymilo_obj import Import
915
from .utils.util import get_sklearn_class
@@ -70,10 +76,8 @@ def main():
7076
print(PYMILO_VERSION)
7177
return
7278
if not ml_streaming_support:
73-
print("Error: ML Streaming is not installed.")
74-
print("To install ML Streaming, run the following command:")
75-
print("pip install pymilo[streaming]")
76-
print("For more information, visit the PyMilo README at https://github.com/openscilab/pymilo")
79+
print(CLI_ML_STREAMING_NOT_INSTALLED)
80+
print(CLI_MORE_INFO)
7781
tprint("PyMilo")
7882
tprint("V:" + PYMILO_VERSION)
7983
pymilo_help()
@@ -88,19 +92,17 @@ def main():
8892
path = args.load
8993
run_ps = True
9094
_model = Import(url=path) if re.match(URL_REGEX, path) else Import(file_adr=path)
95+
_model = _model.to_model()
9196
elif args.init:
9297
model_name = args.init
9398
model_class = get_sklearn_class(model_name)
9499
if model_class is None:
95-
print(
96-
"The given ML model name is neither valid nor supported, use the list below: \n{print_supported_ml_models}")
97-
print_supported_ml_models()
100+
print(f"{CLI_UNKNOWN_MODEL}\n{print_supported_ml_models()}")
98101
return
99102
run_ps = True
100103
_model = model_class()
101104
elif args.bare:
102105
run_ps = True
103-
_model = model_class()
104106
if not run_ps:
105107
tprint("PyMilo")
106108
tprint("V:" + PYMILO_VERSION)
@@ -114,5 +116,6 @@ def main():
114116
communication_protocol=_communication_protocol,
115117
).communicator.run()
116118

119+
117120
if __name__ == '__main__':
118121
main()

pymilo/chains/ensemble_chain.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,10 @@ def deserialize(self, ensemble, is_inner_model=False):
170170
setattr(raw_model, item, data[item])
171171
return raw_model
172172

173+
173174
ensemble_chain = EnsembleModelChain(ENSEMBLE_CHAIN, SKLEARN_ENSEMBLE_TABLE)
174175

176+
175177
def get_transporter(model):
176178
"""
177179
Get associated transporter for the given ML model.
@@ -188,6 +190,7 @@ def get_transporter(model):
188190
else:
189191
return get_concrete_transporter(model)
190192

193+
191194
def serialize_possible_ml_model(possible_ml_model):
192195
"""
193196
Check whether the given object is a ML model and if it is, serialize it.
@@ -209,6 +212,7 @@ def serialize_possible_ml_model(possible_ml_model):
209212
else:
210213
return False, possible_ml_model
211214

215+
212216
def deserialize_possible_ml_model(possible_serialized_ml_model):
213217
"""
214218
Check whether the given object is previously serialized ML model and if it is, deserialize it back to the associated ML model.
@@ -226,6 +230,7 @@ def deserialize_possible_ml_model(possible_serialized_ml_model):
226230
else:
227231
return False, possible_serialized_ml_model
228232

233+
229234
def serialize_models_in_ndarray(ndarray_instance):
230235
"""
231236
Serialize the ml models inside the given ndarray.
@@ -268,6 +273,7 @@ def serialize_models_in_ndarray(ndarray_instance):
268273
'pymiloed-data-structure': 'numpy.ndarray'
269274
}
270275

276+
271277
def deserialize_models_in_ndarray(serialized_ndarray):
272278
"""
273279
Deserializes possible ML models within the given ndarray instance.

pymilo/chains/linear_model_chain.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,10 @@ def deserialize(self, linear_model, is_inner_model=False):
7878
setattr(raw_model, item, data[item])
7979
return raw_model
8080

81+
8182
linear_chain = LinearModelChain(LINEAR_MODEL_CHAIN, SKLEARN_LINEAR_MODEL_TABLE)
8283

84+
8385
def is_deserialized_linear_model(content):
8486
"""
8587
Check if the given content is a previously serialized model by Pymilo's Export or not.

pymilo/pymilo_param.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,12 @@
9696
INVALID_DOWNLOADED_MODEL = "The downloaded content is not a valid JSON file."
9797
BATCH_IMPORT_INVALID_DIRECTORY = "The given directory does not exist."
9898

99+
CLI_ML_STREAMING_NOT_INSTALLED = """ML Streaming is not installed.
100+
To install ML Streaming, run the following command:
101+
pip install pymilo[streaming]"""
102+
CLI_MORE_INFO = "For more information, visit the PyMilo README at https://github.com/openscilab/pymilo"
103+
CLI_UNKNOWN_MODEL = "The provided ML model name is either invalid or unsupported."
104+
99105
SKLEARN_LINEAR_MODEL_TABLE = {
100106
"DummyRegressor": dummy.DummyRegressor,
101107
"DummyClassifier": dummy.DummyClassifier,

pymilo/utils/util.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,5 +179,4 @@ def get_sklearn_class(model_name):
179179
for _, category_models in SKLEARN_SUPPORTED_CATEGORIES.items():
180180
if model_name in category_models:
181181
return category_models[model_name]
182-
# todo raise exception
183182
return None

tests/test_ml_streaming/run_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,4 @@ def main():
4848
communicator.run()
4949

5050
if __name__ == '__main__':
51-
main()
51+
main()

tests/test_ml_streaming/test_streaming.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,17 @@
1313
params=["NULL", "GZIP", "ZLIB", "LZMA", "BZ2"])
1414
def prepare_bare_server(request):
1515
compression_method = request.param
16+
# Using PyMilo direct CLI
17+
# server_proc = subprocess.Popen(
18+
# [
19+
# executable,
20+
# "-m", "pymilo",
21+
# "--compression", compression_method,
22+
# "--protocol", "REST",
23+
# "--port", "8000",
24+
# "--bare",
25+
# ],
26+
# )
1627
path = os.path.join(
1728
os.getcwd(),
1829
"tests",
@@ -28,7 +39,7 @@ def prepare_bare_server(request):
2839
],
2940
)
3041
time.sleep(10)
31-
yield (server_proc, compression_method, "REST")
42+
yield (compression_method, "REST")
3243
server_proc.terminate()
3344

3445

@@ -38,7 +49,18 @@ def prepare_bare_server(request):
3849
def prepare_ml_server(request):
3950
communication_protocol = request.param
4051
compression_method = "ZLIB"
41-
print(communication_protocol)
52+
# Using PyMilo direct CLI
53+
# server_proc = subprocess.Popen(
54+
# [
55+
# executable,
56+
# "-m", "pymilo",
57+
# "--compression", compression_method,
58+
# "--protocol", communication_protocol,
59+
# "--port", "9000",
60+
# "--load", os.path.join(os.getcwd(), "tests", "test_exceptions", "valid_jsons", "linear_regression.json")
61+
# # "--load", "https://raw.githubusercontent.com/openscilab/pymilo/main/tests/test_exceptions/valid_jsons/linear_regression.json",
62+
# ],
63+
# )
4264
path = os.path.join(
4365
os.getcwd(),
4466
"tests",
@@ -54,21 +76,21 @@ def prepare_ml_server(request):
5476
"--init",
5577
],
5678
)
57-
time.sleep(5)
58-
yield (server_proc, compression_method, communication_protocol)
79+
time.sleep(10)
80+
yield (compression_method, communication_protocol)
5981
server_proc.terminate()
6082

6183

6284
def test1(prepare_bare_server):
63-
_, compression_method, communication_protocol = prepare_bare_server
85+
compression_method, communication_protocol = prepare_bare_server
6486
assert scenario1(compression_method, communication_protocol) == 0
6587

6688

6789
def test2(prepare_bare_server):
68-
_, compression_method, communication_protocol = prepare_bare_server
90+
compression_method, communication_protocol = prepare_bare_server
6991
assert scenario2(compression_method, communication_protocol) == 0
7092

7193

7294
def test3(prepare_ml_server):
73-
_, compression_method, communication_protocol = prepare_ml_server
95+
compression_method, communication_protocol = prepare_ml_server
7496
assert scenario3(compression_method, communication_protocol) == 0

0 commit comments

Comments
 (0)