Skip to content

Commit 91b993b

Browse files
authored
Merge pull request #7 from FlorianMF/typing
WIP: Add type annotations
2 parents d3a9fa3 + 9e6a8c0 commit 91b993b

File tree

3 files changed

+92
-59
lines changed

3 files changed

+92
-59
lines changed

torchserve_dashboard/api.py

Lines changed: 84 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,73 @@
11
import os
22
import subprocess
3+
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
34

45
import httpx
6+
from httpx import Response
57

68
import logging
79

8-
ENVIRON_WHITELIST = ["LD_LIBRARY_PATH", "LC_CTYPE", "LC_ALL", "PATH", "JAVA_HOME", "PYTHONPATH", "TS_CONFIG_FILE", "LOG_LOCATION", "METRICS_LOCATION"]
10+
ENVIRON_WHITELIST = [
11+
"LD_LIBRARY_PATH", "LC_CTYPE", "LC_ALL", "PATH", "JAVA_HOME", "PYTHONPATH",
12+
"TS_CONFIG_FILE", "LOG_LOCATION", "METRICS_LOCATION"
13+
]
914

1015
log = logging.getLogger(__name__)
1116

1217

1318
class LocalTS:
14-
def __init__(self, model_store, config_path, log_location=None, metrics_location=None):
19+
def __init__(self,
20+
model_store: str,
21+
config_path: str,
22+
log_location: Optional[str] = None,
23+
metrics_location: Optional[str] = None) -> None:
1524
new_env = {}
1625
env = os.environ
1726
for x in ENVIRON_WHITELIST:
1827
if x in env:
1928
new_env[x] = env[x]
2029
if log_location:
2130
new_env["LOG_LOCATION"] = log_location
31+
if not os.path.isdir(log_location):
32+
os.makedirs(log_location, exist_ok=True)
2233
if metrics_location:
2334
new_env["METRICS_LOCATION"] = metrics_location
24-
if not os.path.isdir(metrics_location):
25-
os.makedirs(metrics_location, exist_ok=True)
26-
if not os.path.isdir(log_location):
27-
os.makedirs(log_location, exist_ok=True)
35+
if not os.path.isdir(metrics_location):
36+
os.makedirs(metrics_location, exist_ok=True)
2837

2938
self.model_store = model_store
3039
self.config_path = config_path
3140
self.log_location = log_location
3241
self.metrics_location = metrics_location
3342
self.env = new_env
34-
35-
def check_version(self):
43+
44+
def check_version(self) -> Tuple[str, Union[str, Exception]]:
3645
try:
37-
p=subprocess.run(["torchserve","--version"], check=True,
38-
stdout=subprocess.PIPE,stderr=subprocess.PIPE,
39-
universal_newlines=True)
40-
return p.stdout ,p.stderr
41-
except (subprocess.CalledProcessError,OSError) as e:
42-
return "",e
43-
44-
def start_torchserve(self):
46+
p = subprocess.run(["torchserve", "--version"],
47+
check=True,
48+
stdout=subprocess.PIPE,
49+
stderr=subprocess.PIPE,
50+
universal_newlines=True)
51+
return p.stdout, p.stderr
52+
except (subprocess.CalledProcessError, OSError) as e:
53+
return "", e
54+
55+
def start_torchserve(self) -> str:
4556

4657
if not os.path.exists(self.model_store):
4758
return "Can't find model store path"
4859
elif not os.path.exists(self.config_path):
4960
return "Can't find configuration path"
50-
dashboard_log_path = os.path.join(self.log_location, "torchserve_dashboard.log")
61+
dashboard_log_path = os.path.join(
62+
self.log_location, "torchserve_dashboard.log"
63+
) if self.log_location is not None else None
5164
torchserve_cmd = f"torchserve --start --ncs --model-store {self.model_store} --ts-config {self.config_path}"
5265
p = subprocess.Popen(
5366
torchserve_cmd.split(" "),
5467
env=self.env,
5568
stdout=subprocess.DEVNULL,
56-
stderr=open(dashboard_log_path, "a+"),
69+
stderr=open(dashboard_log_path, "a+")
70+
if dashboard_log_path else subprocess.DEVNULL,
5771
start_new_session=True,
5872
close_fds=True # IDK stackoverflow told me to do it
5973
)
@@ -63,33 +77,41 @@ def start_torchserve(self):
6377
else:
6478
return f"Torchserve is already started. Check {dashboard_log_path} for errors"
6579

66-
def stop_torchserve(self):
80+
def stop_torchserve(self) -> Union[str, Exception]:
6781
try:
68-
p=subprocess.run(["torchserve","--stop"], check=True,
69-
stdout=subprocess.PIPE,stderr=subprocess.PIPE,
70-
universal_newlines=True)
82+
p = subprocess.run(["torchserve", "--stop"],
83+
check=True,
84+
stdout=subprocess.PIPE,
85+
stderr=subprocess.PIPE,
86+
universal_newlines=True)
7187
return p.stdout
72-
except (subprocess.CalledProcessError,OSError) as e:
88+
except (subprocess.CalledProcessError, OSError) as e:
7389
return e
7490

75-
class ManagementAPI:
7691

77-
def __init__(self, address, error_callback):
92+
class ManagementAPI:
93+
def __init__(self, address: str, error_callback: Callable = None) -> None:
7894
self.address = address
79-
self.client = httpx.Client(timeout=1000, event_hooks={"response": [error_callback]})
80-
81-
def default_error_callback(response):
95+
if not error_callback:
96+
error_callback=self.default_error_callback
97+
self.client = httpx.Client(timeout=1000,
98+
event_hooks={"response": [error_callback]})
99+
@staticmethod
100+
def default_error_callback(response: Response) -> None:
82101
if response.status_code != 200:
83102
log.info(f"Warn - status code: {response.status_code},{response}")
84103

85-
def get_loaded_models(self):
104+
def get_loaded_models(self) -> Optional[Dict[str, Any]]:
86105
try:
87106
res = self.client.get(self.address + "/models")
88107
return res.json()
89108
except httpx.HTTPError:
90109
return None
91110

92-
def get_model(self, model_name, version=None, list_all=False):
111+
def get_model(self,
112+
model_name: str,
113+
version: Optional[str] = None,
114+
list_all: bool = False) -> List[Dict[str, Any]]:
93115
req_url = self.address + "/models/" + model_name
94116
if version:
95117
req_url += "/" + version
@@ -102,15 +124,15 @@ def get_model(self, model_name, version=None, list_all=False):
102124
# Doesn't have version
103125
def register_model(
104126
self,
105-
mar_path,
106-
model_name=None,
107-
handler=None,
108-
runtime=None,
109-
batch_size=None,
110-
max_batch_delay=None,
111-
initial_workers=None,
112-
response_timeout=None,
113-
):
127+
mar_path: str,
128+
model_name: Optional[str] = None,
129+
handler: Optional[str] = None,
130+
runtime: Optional[str] = None,
131+
batch_size: Optional[int] = None,
132+
max_batch_delay: Optional[int] = None,
133+
initial_workers: Optional[int] = None,
134+
response_timeout: Optional[int] = None,
135+
) -> Dict[str, str]:
114136

115137
req_url = self.address + "/models?url=" + mar_path + "&synchronous=false"
116138
if model_name:
@@ -131,22 +153,32 @@ def register_model(
131153
res = self.client.post(req_url)
132154
return res.json()
133155

134-
def delete_model(self, model_name, version):
156+
def delete_model(self,
157+
model_name: str,
158+
version: Optional[str] = None) -> Dict[str, str]:
135159
req_url = self.address + "/models/" + model_name
136160
if version:
137161
req_url += "/" + version
138162
res = self.client.delete(req_url)
139163
return res.json()
140164

141-
def change_model_default(self, model_name, version):
165+
def change_model_default(self,
166+
model_name: str,
167+
version: Optional[str] = None):
142168
req_url = self.address + "/models/" + model_name
143169
if version:
144170
req_url += "/" + version
145171
req_url += "/set-default"
146172
res = self.client.put(req_url)
147173
return res.json()
148174

149-
def change_model_workers(self, model_name, version=None, min_worker=None, max_worker=None, number_gpu=None):
175+
def change_model_workers(
176+
self,
177+
model_name: str,
178+
version: Optional[str] = None,
179+
min_worker: Optional[int] = None,
180+
max_worker: Optional[int] = None,
181+
number_gpu: Optional[int] = None) -> Dict[str, str]:
150182
req_url = self.address + "/models/" + model_name
151183
if version:
152184
req_url += "/" + version
@@ -160,28 +192,30 @@ def change_model_workers(self, model_name, version=None, min_worker=None, max_wo
160192
res = self.client.put(req_url)
161193
return res.json()
162194

163-
def register_workflow(
164-
self,
165-
url,
166-
workflow_name=None
167-
):
195+
def register_workflow(self,
196+
url: str,
197+
workflow_name: Optional[str] = None) -> Dict[str, str]:
168198
req_url = self.address + "/workflows/" + url
169199
if workflow_name:
170200
req_url += "&workflow_name=" + workflow_name
171201
res = self.client.post(req_url)
172202
return res.json()
173203

174-
def get_workflow(self, workflow_name):
204+
def get_workflow(self, workflow_name: str) -> Dict[str, str]:
175205
req_url = self.address + "/workflows/" + workflow_name
176206
res = self.client.get(req_url)
177207
return res.json()
178208

179-
def unregister_workflow(self, workflow_name):
209+
def unregister_workflow(self, workflow_name: str) -> Dict[str, str]:
180210
req_url = self.address + "/workflows/" + workflow_name
181211
res = self.client.delete(req_url)
182212
return res.json()
183213

184-
def list_workflows(self, limit=None, next_page_token=None):
214+
def list_workflows(
215+
self,
216+
limit: Optional[int] = None,
217+
next_page_token: Optional[int] = None
218+
) -> Optional[Dict[str, Any]]:
185219
req_url = self.address + "/workflows/"
186220
if limit:
187221
req_url += "&limit=" + str(limit)

torchserve_dashboard/cli.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
1+
from typing import Any
12
import click
23
import streamlit.cli
34
from streamlit.cli import configurator_options
45
import os
56

6-
@click.command(context_settings=dict(
7-
ignore_unknown_options=True,
8-
allow_extra_args = True
9-
))
7+
8+
@click.command(context_settings=dict(ignore_unknown_options=True,
9+
allow_extra_args=True))
1010
@configurator_options
1111
@click.argument("args", nargs=-1)
1212
@click.pass_context
13-
def main(ctx,args,**kwargs):
13+
def main(ctx: click.Context, args: Any, **kwargs: Any):
1414
dirname = os.path.dirname(__file__)
1515
filename = os.path.join(dirname, 'dash.py')
1616
ctx.forward(streamlit.cli.main_run, target=filename, args=args, *kwargs)
17-

torchserve_dashboard/dash.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
import os
33

44
import streamlit as st
5+
from httpx import Response
56

6-
from api import ManagementAPI, LocalTS
7+
from torchserve_dashboard.api import ManagementAPI, LocalTS
78
from pathlib import Path
89

910
st.set_page_config(
@@ -48,7 +49,6 @@ def check_args(args):
4849
if not os.path.exists(config_path):
4950
st.write(f"Can't find config file at {config_path}. Using default config instead")
5051
config_path = os.path.join(os.path.dirname(__file__), "default.torchserve.properties")
51-
5252
if os.path.exists(config_path):
5353
config = open(config_path, "r").readlines()
5454
for c in config:
@@ -85,7 +85,7 @@ def rerun():
8585
st.experimental_rerun()
8686

8787

88-
def error_callback(response):
88+
def error_callback(response:Response):
8989
if response.status_code != 200:
9090
st.write("There was an error!")
9191
st.write(response)

0 commit comments

Comments
 (0)