Skip to content

Commit 53d39ac

Browse files
authored
Feature/cli (#174)
* add art * add __main__ (for CLI) to `setup.py` * add overview to `pymilo_param.py` * add url regex * add sklearn supported categories dictionary * add `print_supported_ml_models` for CLI * add `pymilo_help` for CLI * develop `get_sklearn_class` function * implement CLI for pymilo(+pymilo server) * apply `autopep8.sh` * `CHANGELOG.md` updated * add docstring to `main` function * update docstring in `get_sklearn_class` function * apply minor refactoring * refactor to run `communicator.run()` once * `CHANGELOG.md` updated * apply refactoring to `if-else`'s
1 parent d2f5131 commit 53d39ac

File tree

7 files changed

+190
-3
lines changed

7 files changed

+190
-3
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
66

77
## [Unreleased]
88
### Added
9+
- `CLI` handler
10+
- `print_supported_ml_models` function in `pymilo_func.py`
11+
- `pymilo_help` function in `pymilo_func.py`
12+
- `SKLEARN_SUPPORTED_CATEGORIES` in `pymilo_param.py`
13+
- `OVERVIEW` in `pymilo_param.py`
14+
- `get_sklearn_class` in `utils.util.py`
915
### Changed
1016
- `to_pymilo_issue` function in `PymiloException`
1117
- `valid_url_valid_file` testcase added in `test_exceptions.py`

pymilo/__main__.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# -*- coding: utf-8 -*-
2+
"""PyMilo main."""
3+
import re
4+
import argparse
5+
from art import tprint
6+
from .pymilo_param import PYMILO_VERSION, URL_REGEX
7+
from .pymilo_func import print_supported_ml_models, pymilo_help
8+
from .pymilo_obj import Import
9+
from .utils.util import get_sklearn_class
10+
11+
ml_streaming_support = True
12+
try:
13+
from .streaming import PymiloServer, Compression, CommunicationProtocol
14+
except BaseException:
15+
ml_streaming_support = False
16+
17+
18+
def main():
19+
"""
20+
CLI main function.
21+
22+
:return: None
23+
"""
24+
parser = argparse.ArgumentParser(description='Run the Pymilo server with a specified compression method.')
25+
parser.add_argument(
26+
'--compression',
27+
type=str,
28+
choices=['NULL', 'GZIP', 'ZLIB', 'LZMA', 'BZ2'],
29+
default='NULL',
30+
help='Specify the compression method (NULL, GZIP, ZLIB, LZMA, or BZ2). Default is NULL.'
31+
)
32+
parser.add_argument(
33+
'--port',
34+
type=int,
35+
default=8000,
36+
help='Specify PyMiloServer port number',
37+
metavar="",
38+
)
39+
parser.add_argument(
40+
'--protocol',
41+
type=str,
42+
choices=['REST', 'WEBSOCKET'],
43+
default='REST',
44+
help='Specify the communication protocol (REST or WEBSOCKET). Default is REST.'
45+
)
46+
parser.add_argument(
47+
'--load',
48+
type=str,
49+
default=None,
50+
help='the `load` command specifies the path to the JSON file of the previously exported ML model by PyMilo.',
51+
metavar="",
52+
)
53+
parser.add_argument(
54+
'--init',
55+
type=str,
56+
default=None,
57+
help='the `init` command specifies the ML model to initialize the PyMilo Server with.',
58+
metavar="",
59+
)
60+
parser.add_argument(
61+
'--bare',
62+
default=False,
63+
action='store_true',
64+
help='The `bare` command starts the PyMilo Server without an internal ML model.',
65+
)
66+
parser.add_argument('--version', action='store_true', default=False, help='PyMilo version')
67+
parser.add_argument('-v', action='store_true', default=False, help='PyMilo version')
68+
args = parser.parse_args()
69+
if args.version or args.v:
70+
print(PYMILO_VERSION)
71+
return
72+
if not ml_streaming_support:
73+
print("Error: ML Streaming is not installed.")
74+
print("To install ML Streaming, run the following command:")
75+
print("pip install pymilo[streaming]")
76+
print("For more information, visit the PyMilo README at https://github.com/openscilab/pymilo")
77+
tprint("PyMilo")
78+
tprint("V:" + PYMILO_VERSION)
79+
pymilo_help()
80+
parser.print_help()
81+
return
82+
run_ps = False
83+
_model = None
84+
_port = args.port
85+
_compressor = Compression[args.compression]
86+
_communication_protocol = CommunicationProtocol[args.protocol]
87+
if args.load:
88+
path = args.load
89+
run_ps = True
90+
_model = Import(url=path) if re.match(URL_REGEX, path) else Import(file_adr=path)
91+
elif args.init:
92+
model_name = args.init
93+
model_class = get_sklearn_class(model_name)
94+
if model_class is None:
95+
print(
96+
"The given ML model name is neither valid nor supported, use the list below: \n{print_supported_ml_models}")
97+
print_supported_ml_models()
98+
return
99+
run_ps = True
100+
_model = model_class()
101+
elif args.bare:
102+
run_ps = True
103+
_model = model_class()
104+
if not run_ps:
105+
tprint("PyMilo")
106+
tprint("V:" + PYMILO_VERSION)
107+
pymilo_help()
108+
parser.print_help()
109+
else:
110+
PymiloServer(
111+
model=_model,
112+
port=_port,
113+
compressor=_compressor,
114+
communication_protocol=_communication_protocol,
115+
).communicator.run()
116+
117+
if __name__ == '__main__':
118+
main()

pymilo/pymilo_func.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from .chains.ensemble_chain import get_transporter
77
from .transporters.transporter import Command
8+
from .pymilo_param import SKLEARN_SUPPORTED_CATEGORIES, NOT_SUPPORTED, OVERVIEW
89

910

1011
def get_sklearn_version():
@@ -62,3 +63,28 @@ def compare_model_outputs(exported_output,
6263
return False # TODO: throw exception
6364
total_error += np.abs(imported_output[key] - exported_output[key])
6465
return np.abs(total_error) < epsilon_error
66+
67+
68+
def print_supported_ml_models():
69+
"""
70+
Print the supported sklearn ML models categorized by type.
71+
72+
:return: None
73+
"""
74+
print("Supported Machine Learning Models:")
75+
for category, table in SKLEARN_SUPPORTED_CATEGORIES.items():
76+
print(f"**{category}**:")
77+
for model_name in table:
78+
if table[model_name] != NOT_SUPPORTED:
79+
print(f"- {model_name}")
80+
81+
82+
def pymilo_help():
83+
"""
84+
Print PyMilo details.
85+
86+
:return: None
87+
"""
88+
print(OVERVIEW)
89+
print("Repo : https://github.com/openscilab/pymilo")
90+
print("Webpage : https://openscilab.com/\n")

pymilo/pymilo_param.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,16 @@
8282
except BaseException:
8383
pass
8484

85-
85+
OVERVIEW = """
86+
PyMilo is an open source Python package that provides a simple, efficient, and safe way for users to export pre-trained machine learning models in a transparent way.
87+
"""
8688
PYMILO_VERSION = "1.1"
8789
NOT_SUPPORTED = "NOT_SUPPORTED"
8890
PYMILO_VERSION_DOES_NOT_EXIST = "Corrupted JSON file, `pymilo_version` doesn't exist in this file."
8991
UNEQUAL_PYMILO_VERSIONS = "warning: Installed PyMilo version differs from the PyMilo version used to create the JSON file."
9092
UNEQUAL_SKLEARN_VERSIONS = "warning: Installed Scikit version differs from the Scikit version used to create the JSON file and it may prevent PyMilo from transporting seamlessly."
9193
INVALID_IMPORT_INIT_PARAMS = "Invalid input parameters, you should either pass a valid file_adr or a json_dump or a url to initiate Import class."
94+
URL_REGEX = r'^(http|https)://[a-zA-Z0-9.-_]+\.[a-zA-Z]{2,}(/\S*)?$'
9295
DOWNLOAD_MODEL_FAILED = "Failed to download the JSON file, Server didn't respond."
9396
INVALID_DOWNLOADED_MODEL = "The downloaded content is not a valid JSON file."
9497
BATCH_IMPORT_INVALID_DIRECTORY = "The given directory does not exist."
@@ -276,3 +279,15 @@
276279
"ENSEMBLE": "exported_ensembles",
277280
"CROSS_DECOMPOSITION": "exported_cross_decomposition",
278281
}
282+
283+
SKLEARN_SUPPORTED_CATEGORIES = {
284+
"LINEAR_MODEL": SKLEARN_LINEAR_MODEL_TABLE,
285+
"NEURAL_NETWORK": SKLEARN_NEURAL_NETWORK_TABLE,
286+
"DECISION_TREE": SKLEARN_DECISION_TREE_TABLE,
287+
"CLUSTERING": SKLEARN_CLUSTERING_TABLE,
288+
"NAIVE_BAYES": SKLEARN_NAIVE_BAYES_TABLE,
289+
"SVM": SKLEARN_SVM_TABLE,
290+
"NEIGHBORS": SKLEARN_NEIGHBORS_TABLE,
291+
"ENSEMBLE": SKLEARN_ENSEMBLE_TABLE,
292+
"CROSS_DECOMPOSITION": SKLEARN_CROSS_DECOMPOSITION_TABLE,
293+
}

pymilo/utils/util.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import requests
44
import importlib
55
from inspect import signature
6-
from ..pymilo_param import DOWNLOAD_MODEL_FAILED, INVALID_DOWNLOADED_MODEL
6+
from ..pymilo_param import DOWNLOAD_MODEL_FAILED, INVALID_DOWNLOADED_MODEL, SKLEARN_SUPPORTED_CATEGORIES
77

88

99
def get_sklearn_type(model):
@@ -165,3 +165,19 @@ def download_model(url):
165165
return response.json()
166166
except ValueError:
167167
raise Exception(INVALID_DOWNLOADED_MODEL)
168+
169+
170+
def get_sklearn_class(model_name):
171+
"""
172+
Return the sklearn class of the requested model name.
173+
174+
:param model_name: model name
175+
:type model_name: str
176+
177+
:return: sklearn ML model class
178+
"""
179+
for _, category_models in SKLEARN_SUPPORTED_CATEGORIES.items():
180+
if model_name in category_models:
181+
return category_models[model_name]
182+
# todo raise exception
183+
return None

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1+
art>=1.8
12
numpy>=1.9.0
3+
requests>=2.0.0
24
scikit-learn>=0.22.2
35
scipy>=0.19.1
4-
requests>=2.0.0

setup.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,9 @@ def read_description():
7979
'Topic :: Scientific/Engineering :: Physics',
8080
],
8181
license='MIT',
82+
entry_points={
83+
'console_scripts': [
84+
'pymilo = pymilo.__main__:main',
85+
]
86+
}
8287
)

0 commit comments

Comments
 (0)