Skip to content

Commit 9859689

Browse files
authored
Merge pull request #3 from cceyda/refactor_workflow
Update to v0.4
2 parents 2d43224 + b0ec0c6 commit 9859689

File tree

5 files changed

+327
-194
lines changed

5 files changed

+327
-194
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,5 @@ ENV/
102102
.mypy_cache/
103103

104104
# IDE settings
105-
.vscode/
105+
.vscode/
106+
/*.ipynb

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ If the server doesn't start for some reason check if your ports are already in u
5454

5555
[10-may-2021] Update config and make it optional. Update streamlit. Auto-create folders.
5656

57+
[31-may-2021] Update to v0.4 (add workflow API). Refactor streamlit out of api.py.
58+
5759
# FAQs
5860
- **Does torchserve keep running in the background?**
5961

@@ -79,4 +81,8 @@ If the server doesn't start for some reason check if your ports are already in u
7981

8082
Open an issue
8183

84+
# TODOs
85+
- Async?
86+
- Better logging
87+
- Remote only mode
8288

setup.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,19 @@
33
"""
44
from setuptools import find_packages, setup
55

6-
dependencies = ["streamlit==0.81.1", "click<8.0,>=7.0", "httpx>=0.16.0"]
6+
7+
dependencies = ["streamlit==0.82.0", "click<8.0,>=7.0", "httpx>=0.16.0"]
78

89
setup(
910
name="torchserve_dashboard",
10-
version="v0.3.3",
11+
version="v0.4.0",
1112
url="https://github.com/cceyda/torchserve-dashboard",
1213
license="Apache Software License 2.0",
1314
author="Ceyda Cinarel",
1415
author_email="[email protected]",
1516
description="Torchserve dashboard using Streamlit",
1617
long_description=__doc__,
17-
packages=find_packages(exclude=["tests","assets"]),
18+
packages=find_packages(exclude=["tests", "assets"]),
1819
include_package_data=True,
1920
zip_safe=False,
2021
platforms="any",

torchserve_dashboard/api.py

Lines changed: 182 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -3,132 +3,192 @@
33

44
import httpx
55

6-
import streamlit as st
7-
8-
ENVIRON_WHITELIST=["LD_LIBRARY_PATH","LC_CTYPE","LC_ALL","PATH","JAVA_HOME","PYTHONPATH","TS_CONFIG_FILE","LOG_LOCATION","METRICS_LOCATION"]
9-
10-
def raise_on_not200(response):
11-
if response.status_code != 200:
12-
st.write("There was an error!")
13-
st.write(response)
14-
15-
client = httpx.Client(timeout=1000, event_hooks={"response": [raise_on_not200]})
16-
17-
18-
def start_torchserve(model_store, config_path, log_location=None, metrics_location=None):
19-
new_env={}
20-
env=os.environ
21-
for x in ENVIRON_WHITELIST:
22-
if x in env:
23-
new_env[x]=env[x]
24-
if log_location:
25-
new_env["LOG_LOCATION"]=log_location
26-
if metrics_location:
27-
new_env["METRICS_LOCATION"]=metrics_location
28-
if not os.path.isdir(metrics_location):
29-
os.makedirs(metrics_location, exist_ok=True)
30-
if not os.path.isdir(log_location):
31-
os.makedirs(log_location, exist_ok=True)
32-
if not os.path.exists(model_store):
33-
return "Can't find model store path"
34-
elif not os.path.exists(config_path):
35-
return "Can't find configuration path"
36-
else:
37-
torchserve_cmd = f"torchserve --start --ncs --model-store {model_store} --ts-config {config_path}"
38-
subprocess.Popen(
6+
import logging
7+
8+
ENVIRON_WHITELIST = ["LD_LIBRARY_PATH", "LC_CTYPE", "LC_ALL", "PATH", "JAVA_HOME", "PYTHONPATH", "TS_CONFIG_FILE", "LOG_LOCATION", "METRICS_LOCATION"]
9+
10+
log = logging.getLogger(__name__)
11+
12+
13+
class LocalTS:
14+
def __init__(self, model_store, config_path, log_location=None, metrics_location=None):
15+
new_env = {}
16+
env = os.environ
17+
for x in ENVIRON_WHITELIST:
18+
if x in env:
19+
new_env[x] = env[x]
20+
if log_location:
21+
new_env["LOG_LOCATION"] = log_location
22+
if metrics_location:
23+
new_env["METRICS_LOCATION"] = metrics_location
24+
if not os.path.isdir(metrics_location):
25+
os.makedirs(metrics_location, exist_ok=True)
26+
if not os.path.isdir(log_location):
27+
os.makedirs(log_location, exist_ok=True)
28+
29+
self.model_store = model_store
30+
self.config_path = config_path
31+
self.log_location = log_location
32+
self.metrics_location = metrics_location
33+
self.env = new_env
34+
35+
def check_version(self):
36+
try:
37+
p=subprocess.run(["torchserve","--version"], check=True,
38+
stdout=subprocess.PIPE,stderr=subprocess.PIPE,
39+
universal_newlines=True)
40+
return p.stdout ,p.stderr
41+
except (subprocess.CalledProcessError,OSError) as e:
42+
return "",e
43+
44+
def start_torchserve(self):
45+
46+
if not os.path.exists(self.model_store):
47+
return "Can't find model store path"
48+
elif not os.path.exists(self.config_path):
49+
return "Can't find configuration path"
50+
dashboard_log_path = os.path.join(self.log_location, "torchserve_dashboard.log")
51+
torchserve_cmd = f"torchserve --start --ncs --model-store {self.model_store} --ts-config {self.config_path}"
52+
p = subprocess.Popen(
3953
torchserve_cmd.split(" "),
40-
env=new_env,
41-
stdout=open("/dev/null", "r"),
42-
stderr=open("/dev/null", "w"),
43-
preexec_fn=os.setpgrp,
54+
env=self.env,
55+
stdout=subprocess.DEVNULL,
56+
stderr=open(dashboard_log_path, "a+"),
57+
start_new_session=True,
58+
close_fds=True # IDK stackoverflow told me to do it
4459
)
45-
return "Torchserve is starting..please refresh page"
60+
p.communicate()
61+
if p.returncode == 0:
62+
return f"Torchserve is starting (PID: {p.pid})..please refresh page"
63+
else:
64+
return f"Torchserve is already started. Check {dashboard_log_path} for errors"
65+
66+
def stop_torchserve(self):
67+
try:
68+
p=subprocess.run(["torchserve","--stop"], check=True,
69+
stdout=subprocess.PIPE,stderr=subprocess.PIPE,
70+
universal_newlines=True)
71+
return p.stdout
72+
except (subprocess.CalledProcessError,OSError) as e:
73+
return e
74+
75+
class ManagementAPI:
76+
77+
def __init__(self, address, error_callback):
78+
self.address = address
79+
self.client = httpx.Client(timeout=1000, event_hooks={"response": [error_callback]})
80+
81+
def default_error_callback(response):
82+
if response.status_code != 200:
83+
log.info(f"Warn - status code: {response.status_code},{response}")
84+
85+
def get_loaded_models(self):
86+
try:
87+
res = self.client.get(self.address + "/models")
88+
return res.json()
89+
except httpx.HTTPError:
90+
return None
91+
92+
def get_model(self, model_name, version=None, list_all=False):
93+
req_url = self.address + "/models/" + model_name
94+
if version:
95+
req_url += "/" + version
96+
elif list_all:
97+
req_url += "/all"
98+
99+
res = self.client.get(req_url)
100+
return res.json()
101+
102+
# NOTE: register_model does not take a `version` argument
103+
def register_model(
104+
self,
105+
mar_path,
106+
model_name=None,
107+
handler=None,
108+
runtime=None,
109+
batch_size=None,
110+
max_batch_delay=None,
111+
initial_workers=None,
112+
response_timeout=None,
113+
):
114+
115+
req_url = self.address + "/models?url=" + mar_path + "&synchronous=false"
116+
if model_name:
117+
req_url += "&model_name=" + model_name
118+
if handler:
119+
req_url += "&handler=" + handler
120+
if runtime:
121+
req_url += "&runtime=" + runtime
122+
if batch_size:
123+
req_url += "&batch_size=" + str(batch_size)
124+
if max_batch_delay:
125+
req_url += "&max_batch_delay=" + str(max_batch_delay)
126+
if initial_workers:
127+
req_url += "&initial_workers=" + str(initial_workers)
128+
if response_timeout:
129+
req_url += "&response_timeout=" + str(response_timeout)
130+
131+
res = self.client.post(req_url)
132+
return res.json()
133+
134+
def delete_model(self, model_name, version):
135+
req_url = self.address + "/models/" + model_name
136+
if version:
137+
req_url += "/" + version
138+
res = self.client.delete(req_url)
139+
return res.json()
140+
141+
def change_model_default(self, model_name, version):
142+
req_url = self.address + "/models/" + model_name
143+
if version:
144+
req_url += "/" + version
145+
req_url += "/set-default"
146+
res = self.client.put(req_url)
147+
return res.json()
148+
149+
def change_model_workers(self, model_name, version=None, min_worker=None, max_worker=None, number_gpu=None):
150+
req_url = self.address + "/models/" + model_name
151+
if version:
152+
req_url += "/" + version
153+
req_url += "?synchronous=false"
154+
if min_worker:
155+
req_url += "&min_worker=" + str(min_worker)
156+
if max_worker:
157+
req_url += "&max_worker=" + str(max_worker)
158+
if number_gpu:
159+
req_url += "&number_gpu=" + str(number_gpu)
160+
res = self.client.put(req_url)
161+
return res.json()
46162

163+
def register_workflow(
164+
self,
165+
url,
166+
workflow_name=None
167+
):
168+
req_url = self.address + "/workflows/" + url
169+
if workflow_name:
170+
req_url += "&workflow_name=" + workflow_name
171+
res = self.client.post(req_url)
172+
return res.json()
47173

48-
def stop_torchserve():
49-
subprocess.Popen(["torchserve", "--stop"])
50-
return "Torchserve stopped"
174+
def get_workflow(self, workflow_name):
175+
req_url = self.address + "/workflows/" + workflow_name
176+
res = self.client.get(req_url)
177+
return res.json()
51178

179+
def unregister_workflow(self, workflow_name):
180+
req_url = self.address + "/workflows/" + workflow_name
181+
res = self.client.delete(req_url)
182+
return res.json()
52183

53-
def get_loaded_models(M_API):
54-
try:
55-
res = client.get(M_API + "/models")
184+
def list_workflows(self, limit=None, next_page_token=None):
185+
req_url = self.address + "/workflows/"
186+
if limit:
187+
req_url += "&limit=" + str(limit)
188+
if next_page_token:
189+
req_url += "&next_page_token=" + str(next_page_token)
190+
try:
191+
res = self.client.get(req_url)
192+
except httpx.HTTPError:
193+
return None
56194
return res.json()
57-
except httpx.HTTPError as exc:
58-
return None
59-
60-
61-
def get_model(M_API, model_name, version=None, list_all=False):
62-
req_url = M_API + "/models/" + model_name
63-
if version:
64-
req_url += "/" + version
65-
elif list_all:
66-
req_url += "/all"
67-
68-
res = client.get(req_url)
69-
return res.json()
70-
71-
72-
def register_model(
73-
M_API,
74-
mar_path,
75-
model_name=None,
76-
version=None,
77-
handler=None,
78-
runtime=None,
79-
batch_size=None,
80-
max_batch_delay=None,
81-
initial_workers=None,
82-
response_timeout=None,
83-
):
84-
85-
req_url = M_API + "/models?url=" + mar_path + "&synchronous=false"
86-
if model_name:
87-
req_url += "&model_name=" + model_name
88-
if handler:
89-
req_url += "&handler=" + handler
90-
if runtime:
91-
req_url += "&runtime=" + runtime
92-
if batch_size:
93-
req_url += "&batch_size=" + str(batch_size)
94-
if max_batch_delay:
95-
req_url += "&max_batch_delay=" + str(max_batch_delay)
96-
if initial_workers:
97-
req_url += "&initial_workers=" + str(initial_workers)
98-
if response_timeout:
99-
req_url += "&response_timeout=" + str(response_timeout)
100-
101-
res = client.post(req_url)
102-
return res.json()
103-
104-
105-
def delete_model(M_API, model_name, version):
106-
req_url = M_API + "/models/" + model_name
107-
if version:
108-
req_url += "/" + version
109-
res = client.delete(req_url)
110-
return res.json()
111-
112-
113-
def change_model_default(M_API, model_name, version):
114-
req_url = M_API + "/models/" + model_name
115-
if version:
116-
req_url += "/" + version
117-
req_url += "/set-default"
118-
res = client.put(req_url)
119-
return res.json()
120-
121-
122-
def change_model_workers(M_API, model_name, version=None, min_worker=None, max_worker=None, number_gpu=None):
123-
req_url = M_API + "/models/" + model_name
124-
if version:
125-
req_url += "/" + version
126-
req_url += "?synchronous=false"
127-
if min_worker:
128-
req_url += "&min_worker=" + str(min_worker)
129-
if max_worker:
130-
req_url += "&max_worker=" + str(max_worker)
131-
if number_gpu:
132-
req_url += "&number_gpu=" + str(number_gpu)
133-
res = client.put(req_url)
134-
return res.json()

0 commit comments

Comments
 (0)