Skip to content

Commit 466908c

Browse files
author
Ceyda Cinarel
committed
update streamlit & add checks
1 parent ea0eb60 commit 466908c

File tree

3 files changed

+101
-59
lines changed

3 files changed

+101
-59
lines changed

setup.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
"""
44
from setuptools import find_packages, setup
55

6-
dependencies = ["streamlit>=0.76", "click>=7.1.2", "httpx>=0.16.0"]
6+
dependencies = ["streamlit==0.81.1", "click>=7.1.2", "httpx>=0.16.0"]
77

88
setup(
99
name="torchserve_dashboard",
1010
version="v0.2.5",
1111
url="https://github.com/cceyda/torchserve-dashboard",
1212
license="Apache Software License 2.0",
1313
author="Ceyda Cinarel",
14-
author_email="snu.ceyda@gmail.com",
14+
author_email="[email protected]",
1515
description="Torchserve dashboard using Streamlit",
1616
long_description=__doc__,
1717
packages=find_packages(exclude=["tests","assets"]),
@@ -25,23 +25,12 @@
2525
""",
2626
classifiers=[
2727
# As from http://pypi.python.org/pypi?%3Aaction=list_classifiers
28-
# 'Development Status :: 1 - Planning',
29-
# 'Development Status :: 2 - Pre-Alpha',
30-
# 'Development Status :: 3 - Alpha',
3128
"Development Status :: 4 - Beta",
32-
# 'Development Status :: 5 - Production/Stable',
33-
# 'Development Status :: 6 - Mature',
34-
# 'Development Status :: 7 - Inactive',
35-
"Environment :: Console",
3629
"Intended Audience :: Developers",
37-
"License :: OSI Approved :: BSD License",
38-
"Operating System :: POSIX",
30+
"License :: OSI Approved :: Apache Software License",
3931
"Operating System :: MacOS",
4032
"Operating System :: Unix",
4133
"Operating System :: Microsoft :: Windows",
42-
"Programming Language :: Python",
43-
"Programming Language :: Python :: 2",
4434
"Programming Language :: Python :: 3",
45-
"Topic :: Software Development :: Libraries :: Python Modules",
4635
],
4736
)

torchserve_dashboard/api.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ def raise_on_not200(response):
1212
st.write("There was an error!")
1313
st.write(response)
1414

15-
1615
client = httpx.Client(timeout=1000, event_hooks={"response": [raise_on_not200]})
1716

1817

@@ -22,12 +21,19 @@ def start_torchserve(model_store, config_path, log_location=None, metrics_locati
2221
for x in ENVIRON_WHITELIST:
2322
if x in env:
2423
new_env[x]=env[x]
25-
2624
if log_location:
2725
new_env["LOG_LOCATION"]=log_location
2826
if metrics_location:
2927
new_env["METRICS_LOCATION"]=metrics_location
30-
if os.path.exists(model_store) and os.path.exists(config_path):
28+
if not os.path.isdir(metrics_location):
29+
os.makedirs(metrics_location, exist_ok=True)
30+
if not os.path.isdir(log_location):
31+
os.makedirs(log_location, exist_ok=True)
32+
if not os.path.exists(model_store):
33+
return "Can't find model store path"
34+
elif not os.path.exists(config_path):
35+
return "Can't find configuration path"
36+
else:
3137
torchserve_cmd = f"torchserve --start --ncs --model-store {model_store} --ts-config {config_path}"
3238
subprocess.Popen(
3339
torchserve_cmd.split(" "),

torchserve_dashboard/dash.py

Lines changed: 89 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,30 @@
1717

1818
parser = argparse.ArgumentParser(description="Torchserve dashboard")
1919

20-
parser.add_argument("--model_store", default=None, help="Directory where your models are stored")
21-
parser.add_argument("--config_path", default="./default.torchserve.properties", help="Torchserve config path")
22-
parser.add_argument("--log_location", default="./", help="Passed as environment variable LOG_LOCATION to Torchserve")
23-
parser.add_argument("--metrics_location", default="./", help="Passed as environment variable METRICS_LOCATION to Torchserve")
20+
parser.add_argument(
21+
"--model_store", default=None, help="Directory where your models are stored"
22+
)
23+
parser.add_argument(
24+
"--config_path",
25+
default="./default.torchserve.properties",
26+
help="Torchserve config path",
27+
)
28+
parser.add_argument(
29+
"--log_location",
30+
default="./logs/",
31+
help="Passed as environment variable LOG_LOCATION to Torchserve",
32+
)
33+
parser.add_argument(
34+
"--metrics_location",
35+
default="./logs/metrics/",
36+
help="Passed as environment variable METRICS_LOCATION to Torchserve",
37+
)
2438
try:
2539
args = parser.parse_args()
2640
except SystemExit as e:
2741
os._exit(e.code)
2842

29-
st.title("Torchserve Management Dashboard")
43+
st.title("Torchserve Management Dashboard")
3044

3145
M_API = "http://127.0.0.1:8081"
3246
model_store = args.model_store
@@ -65,6 +79,7 @@ def last_res():
6579
def get_model_store():
6680
return os.listdir(model_store)
6781

82+
6883
# As a design choice I'm leaving config_path,log_location,metrics_location non-editable from the UI as a semi-security measure (maybe?:/)
6984
##########Sidebar##########
7085
st.sidebar.markdown(f"## Help")
@@ -76,21 +91,23 @@ def get_model_store():
7691

7792
start = st.sidebar.button("Start Torchserve")
7893
if start:
79-
last_res()[0]= tsa.start_torchserve(model_store, config_path, log_location, metrics_location)
94+
last_res()[0] = tsa.start_torchserve(
95+
model_store, config_path, log_location, metrics_location
96+
)
8097
rerun()
8198

8299
stop = st.sidebar.button("Stop Torchserve")
83100
if stop:
84101
last_res()[0] = tsa.stop_torchserve()
85102
rerun()
86103

87-
loaded_models = tsa.get_loaded_models(M_API)
88-
if loaded_models:
89-
loaded_models_names = [m["modelName"] for m in loaded_models["models"]]
104+
torchserve_status = tsa.get_loaded_models(M_API)
105+
if torchserve_status:
106+
loaded_models_names = [m["modelName"] for m in torchserve_status["models"]]
90107
else:
91108
st.header("Torchserve is down...")
92109
st.sidebar.subheader("Loaded models")
93-
st.sidebar.write(loaded_models)
110+
st.sidebar.write(torchserve_status)
94111

95112
stored_models = get_model_store()
96113
st.sidebar.subheader("Available models")
@@ -104,23 +121,33 @@ def get_model_store():
104121
st.write(config)
105122
st.markdown("[configuration docs](https://pytorch.org/serve/configuration.html)")
106123

107-
if loaded_models:
124+
if torchserve_status:
108125

109126
with st.beta_expander(label="Register a model", expanded=False):
110127

111-
st.markdown("# Register a model [(docs)](https://pytorch.org/serve/management_api.html#register-a-model)")
128+
st.markdown(
129+
"# Register a model [(docs)](https://pytorch.org/serve/management_api.html#register-a-model)"
130+
)
112131
placeholder = st.empty()
113-
mar_path = placeholder.selectbox("Choose mar file *", [default_key] + stored_models, index=0)
132+
mar_path = placeholder.selectbox(
133+
"Choose mar file *", [default_key] + stored_models, index=0
134+
)
114135
# mar_path = os.path.join(model_store,mar_path)
115136
p = st.checkbox("or use another path")
116137
if p:
117138
mar_path = placeholder.text_input("Input mar file path*")
118139
model_name = st.text_input(label="Model name (overrides predefined)")
119140
col1, col2 = st.beta_columns(2)
120141
batch_size = col1.number_input(label="batch_size", value=0, min_value=0, step=1)
121-
max_batch_delay = col2.number_input(label="max_batch_delay", value=0, min_value=0, step=100)
122-
initial_workers = col1.number_input(label="initial_workers", value=1, min_value=0, step=1)
123-
response_timeout = col2.number_input(label="response_timeout", value=0, min_value=0, step=100)
142+
max_batch_delay = col2.number_input(
143+
label="max_batch_delay", value=0, min_value=0, step=100
144+
)
145+
initial_workers = col1.number_input(
146+
label="initial_workers", value=1, min_value=0, step=1
147+
)
148+
response_timeout = col2.number_input(
149+
label="response_timeout", value=0, min_value=0, step=100
150+
)
124151
handler = col1.text_input(label="handler")
125152
runtime = col2.text_input(label="runtime")
126153

@@ -143,19 +170,21 @@ def get_model_store():
143170
rerun()
144171
else:
145172
st.write(":octagonal_sign: Fill the required fileds!")
146-
147-
148173

149174
with st.beta_expander(label="Remove a model", expanded=False):
150175

151176
st.header("Remove a model")
152-
model_name = st.selectbox("Choose model to remove", [default_key] + loaded_models_names, index=0)
177+
model_name = st.selectbox(
178+
"Choose model to remove", [default_key] + loaded_models_names, index=0
179+
)
153180
if model_name != default_key:
154181
default_version = tsa.get_model(M_API, model_name)[0]["modelVersion"]
155182
st.write(f"default version {default_version}")
156183
versions = tsa.get_model(M_API, model_name, list_all=True)
157184
versions = [m["modelVersion"] for m in versions]
158-
version = st.selectbox("Choose version to remove", [default_key] + versions, index=0)
185+
version = st.selectbox(
186+
"Choose version to remove", [default_key] + versions, index=0
187+
)
159188
proceed = st.button("Remove")
160189
if proceed:
161190
if model_name != default_key and version != default_key:
@@ -164,52 +193,70 @@ def get_model_store():
164193
rerun()
165194
else:
166195
st.write(":octagonal_sign: Pick a model & version!")
167-
196+
168197
with st.beta_expander(label="Get model details", expanded=False):
169198

170199
st.header("Get model details")
171-
model_name = st.selectbox("Choose model", [default_key] + loaded_models_names, index=0)
200+
model_name = st.selectbox(
201+
"Choose model", [default_key] + loaded_models_names, index=0
202+
)
172203
if model_name != default_key:
173-
default_version = tsa.get_model(M_API,model_name)[0]["modelVersion"]
204+
default_version = tsa.get_model(M_API, model_name)[0]["modelVersion"]
174205
st.write(f"default version {default_version}")
175-
versions = tsa.get_model(M_API,model_name, list_all=False)
206+
versions = tsa.get_model(M_API, model_name, list_all=False)
176207
versions = [m["modelVersion"] for m in versions]
177-
version = st.selectbox("Choose version", [default_key, "All"] + versions, index=0)
208+
version = st.selectbox(
209+
"Choose version", [default_key, "All"] + versions, index=0
210+
)
178211
if model_name != default_key:
179212
if version == "All":
180213
res = tsa.get_model(M_API, model_name, list_all=True)
181214
st.write(res)
182215
elif version != default_key:
183216
res = tsa.get_model(M_API, model_name, version)
184217
st.write(res)
185-
218+
186219
with st.beta_expander(label="Scale workers", expanded=False):
187-
st.markdown("# Scale workers [(docs)](https://pytorch.org/serve/management_api.html#scale-workers)")
188-
model_name = st.selectbox("Pick model", [default_key] + loaded_models_names, index=0)
220+
st.markdown(
221+
"# Scale workers [(docs)](https://pytorch.org/serve/management_api.html#scale-workers)"
222+
)
223+
model_name = st.selectbox(
224+
"Pick model", [default_key] + loaded_models_names, index=0
225+
)
189226
if model_name != default_key:
190-
default_version = tsa.get_model(M_API,model_name)[0]["modelVersion"]
227+
default_version = tsa.get_model(M_API, model_name)[0]["modelVersion"]
191228
st.write(f"default version {default_version}")
192-
versions = tsa.get_model(M_API,model_name, list_all=False)
229+
versions = tsa.get_model(M_API, model_name, list_all=False)
193230
versions = [m["modelVersion"] for m in versions]
194231
version = st.selectbox("Choose version", ["All"] + versions, index=0)
195-
232+
196233
col1, col2, col3 = st.beta_columns(3)
197-
min_worker = col1.number_input(label="min_worker(optional)", value=-1, min_value=-1, step=1)
198-
max_worker = col2.number_input(label="max_worker(optional)", value=-1, min_value=-1, step=1)
199-
number_gpu = col3.number_input(label="number_gpu(optional)", value=-1, min_value=-1, step=1)
234+
min_worker = col1.number_input(
235+
label="min_worker(optional)", value=-1, min_value=-1, step=1
236+
)
237+
max_worker = col2.number_input(
238+
label="max_worker(optional)", value=-1, min_value=-1, step=1
239+
)
240+
# number_gpu = col3.number_input(label="number_gpu(optional)", value=-1, min_value=-1, step=1)
200241
proceed = st.button("Apply")
201242
if proceed and model_name != default_key:
202243
# number_input can't be set to None
203244
if version == "All":
204-
version=None
245+
version = None
205246
if min_worker == -1:
206-
min_worker=None
247+
min_worker = None
207248
if max_worker == -1:
208-
max_worker=None
209-
if number_gpu == -1:
210-
number_gpu=None
211-
212-
res = tsa.change_model_workers(M_API, model_name, version=version, min_worker=min_worker, max_worker=max_worker, number_gpu=number_gpu)
249+
max_worker = None
250+
# if number_gpu == -1:
251+
# number_gpu=None
252+
253+
res = tsa.change_model_workers(
254+
M_API,
255+
model_name,
256+
version=version,
257+
min_worker=min_worker,
258+
max_worker=max_worker,
259+
# number_gpu=number_gpu,
260+
)
213261
last_res()[0] = res
214262
rerun()
215-

0 commit comments

Comments (0)