Skip to content

Commit f99c036

Browse files
author
Ceyda Cinarel
committed
add scale workers tab
1 parent 18da1be commit f99c036

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

torchserve_dashboard/api.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import streamlit as st
77

8+
89
def raise_on_not200(response):
910
if response.status_code != 200:
1011
st.write("There was an error!")
@@ -98,3 +99,18 @@ def change_model_default(M_API, model_name, version):
9899
req_url += "/set-default"
99100
res = client.put(req_url)
100101
return res.json()
102+
103+
104+
def change_model_workers(M_API, model_name, version=None, min_worker=None, max_worker=None, number_gpu=None):
105+
req_url = M_API + "/models/" + model_name
106+
if version:
107+
req_url += "/" + version
108+
req_url += "?synchronous=false"
109+
if min_worker:
110+
req_url += "&min_worker=" + str(min_worker)
111+
if max_worker:
112+
req_url += "&max_worker=" + str(max_worker)
113+
if number_gpu:
114+
req_url += "&number_gpu=" + str(number_gpu)
115+
res = client.put(req_url)
116+
return res.json()

torchserve_dashboard/dash.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ def get_model_store():
144144
if proceed and model_name != default_key and version != default_key:
145145
res = tsa.delete_model(M_API, model_name, version)
146146
last_res()[0] = res
147-
proceed=False
148147
rerun()
149148

150149
with st.beta_expander(label="Get model details", expanded=False):
@@ -164,3 +163,34 @@ def get_model_store():
164163
elif version != default_key:
165164
res = tsa.get_model(M_API, model_name, version)
166165
st.write(res)
166+
167+
with st.beta_expander(label="Scale workers", expanded=False):
168+
st.markdown("# Scale workers [(docs)](https://pytorch.org/serve/management_api.html#scale-workers)")
169+
model_name = st.selectbox("Pick model", [default_key] + loaded_models_names, index=0)
170+
if model_name != default_key:
171+
default_version = tsa.get_model(M_API,model_name)[0]["modelVersion"]
172+
st.write(f"default version {default_version}")
173+
versions = tsa.get_model(M_API,model_name, list_all=False)
174+
versions = [m["modelVersion"] for m in versions]
175+
version = st.selectbox("Choose version", ["All"] + versions, index=0)
176+
177+
col1, col2, col3 = st.beta_columns(3)
178+
min_worker = col1.number_input(label="min_worker(optional)", value=-1, min_value=-1, step=1)
179+
max_worker = col2.number_input(label="max_worker(optional)", value=-1, min_value=-1, step=1)
180+
number_gpu = col3.number_input(label="number_gpu(optional)", value=-1, min_value=-1, step=1)
181+
proceed = st.button("Apply")
182+
if proceed and model_name != default_key:
183+
# number_input can't be set to None
184+
if version == "All":
185+
version=None
186+
if min_worker == -1:
187+
min_worker=None
188+
if max_worker == -1:
189+
max_worker=None
190+
if number_gpu == -1:
191+
number_gpu=None
192+
193+
res = tsa.change_model_workers(M_API, model_name, version=version, min_worker=min_worker, max_worker=max_worker, number_gpu=number_gpu)
194+
last_res()[0] = res
195+
rerun()
196+

0 commit comments

Comments
 (0)