11import os
22import subprocess
3+ from typing import Any , Dict , List , Optional , Tuple , Union , Callable
34
45import httpx
6+ from httpx import Response
57
68import logging
79
8- ENVIRON_WHITELIST = ["LD_LIBRARY_PATH" , "LC_CTYPE" , "LC_ALL" , "PATH" , "JAVA_HOME" , "PYTHONPATH" , "TS_CONFIG_FILE" , "LOG_LOCATION" , "METRICS_LOCATION" ]
10+ ENVIRON_WHITELIST = [
11+ "LD_LIBRARY_PATH" , "LC_CTYPE" , "LC_ALL" , "PATH" , "JAVA_HOME" , "PYTHONPATH" ,
12+ "TS_CONFIG_FILE" , "LOG_LOCATION" , "METRICS_LOCATION"
13+ ]
914
1015log = logging .getLogger (__name__ )
1116
1217
1318class LocalTS :
14- def __init__ (self , model_store , config_path , log_location = None , metrics_location = None ):
19+ def __init__ (self ,
20+ model_store : str ,
21+ config_path : str ,
22+ log_location : Optional [str ] = None ,
23+ metrics_location : Optional [str ] = None ) -> None :
1524 new_env = {}
1625 env = os .environ
1726 for x in ENVIRON_WHITELIST :
1827 if x in env :
1928 new_env [x ] = env [x ]
2029 if log_location :
2130 new_env ["LOG_LOCATION" ] = log_location
31+ if not os .path .isdir (log_location ):
32+ os .makedirs (log_location , exist_ok = True )
2233 if metrics_location :
2334 new_env ["METRICS_LOCATION" ] = metrics_location
24- if not os .path .isdir (metrics_location ):
25- os .makedirs (metrics_location , exist_ok = True )
26- if not os .path .isdir (log_location ):
27- os .makedirs (log_location , exist_ok = True )
35+ if not os .path .isdir (metrics_location ):
36+ os .makedirs (metrics_location , exist_ok = True )
2837
2938 self .model_store = model_store
3039 self .config_path = config_path
3140 self .log_location = log_location
3241 self .metrics_location = metrics_location
3342 self .env = new_env
34-
35- def check_version (self ):
43+
44+ def check_version (self ) -> Tuple [ str , Union [ str , Exception ]] :
3645 try :
37- p = subprocess .run (["torchserve" ,"--version" ], check = True ,
38- stdout = subprocess .PIPE ,stderr = subprocess .PIPE ,
39- universal_newlines = True )
40- return p .stdout ,p .stderr
41- except (subprocess .CalledProcessError ,OSError ) as e :
42- return "" ,e
43-
44- def start_torchserve (self ):
46+ p = subprocess .run (["torchserve" , "--version" ],
47+ check = True ,
48+ stdout = subprocess .PIPE ,
49+ stderr = subprocess .PIPE ,
50+ universal_newlines = True )
51+ return p .stdout , p .stderr
52+ except (subprocess .CalledProcessError , OSError ) as e :
53+ return "" , e
54+
55+ def start_torchserve (self ) -> str :
4556
4657 if not os .path .exists (self .model_store ):
4758 return "Can't find model store path"
4859 elif not os .path .exists (self .config_path ):
4960 return "Can't find configuration path"
50- dashboard_log_path = os .path .join (self .log_location , "torchserve_dashboard.log" )
61+ dashboard_log_path = os .path .join (
62+ self .log_location , "torchserve_dashboard.log"
63+ ) if self .log_location is not None else None
5164 torchserve_cmd = f"torchserve --start --ncs --model-store { self .model_store } --ts-config { self .config_path } "
5265 p = subprocess .Popen (
5366 torchserve_cmd .split (" " ),
5467 env = self .env ,
5568 stdout = subprocess .DEVNULL ,
56- stderr = open (dashboard_log_path , "a+" ),
69+ stderr = open (dashboard_log_path , "a+" )
70+ if dashboard_log_path else subprocess .DEVNULL ,
5771 start_new_session = True ,
5872 close_fds = True # IDK stackoverflow told me to do it
5973 )
@@ -63,33 +77,41 @@ def start_torchserve(self):
6377 else :
6478 return f"Torchserve is already started. Check { dashboard_log_path } for errors"
6579
66- def stop_torchserve (self ):
80+ def stop_torchserve (self ) -> Union [ str , Exception ] :
6781 try :
68- p = subprocess .run (["torchserve" ,"--stop" ], check = True ,
69- stdout = subprocess .PIPE ,stderr = subprocess .PIPE ,
70- universal_newlines = True )
82+ p = subprocess .run (["torchserve" , "--stop" ],
83+ check = True ,
84+ stdout = subprocess .PIPE ,
85+ stderr = subprocess .PIPE ,
86+ universal_newlines = True )
7187 return p .stdout
72- except (subprocess .CalledProcessError ,OSError ) as e :
88+ except (subprocess .CalledProcessError , OSError ) as e :
7389 return e
7490
75- class ManagementAPI :
7691
77- def __init__ (self , address , error_callback ):
92+ class ManagementAPI :
93+ def __init__ (self , address : str , error_callback : Callable = None ) -> None :
7894 self .address = address
79- self .client = httpx .Client (timeout = 1000 , event_hooks = {"response" : [error_callback ]})
80-
81- def default_error_callback (response ):
95+ if not error_callback :
96+ error_callback = self .default_error_callback
97+ self .client = httpx .Client (timeout = 1000 ,
98+ event_hooks = {"response" : [error_callback ]})
99+ @staticmethod
100+ def default_error_callback (response : Response ) -> None :
82101 if response .status_code != 200 :
83102 log .info (f"Warn - status code: { response .status_code } ,{ response } " )
84103
85- def get_loaded_models (self ):
104+ def get_loaded_models (self ) -> Optional [ Dict [ str , Any ]] :
86105 try :
87106 res = self .client .get (self .address + "/models" )
88107 return res .json ()
89108 except httpx .HTTPError :
90109 return None
91110
92- def get_model (self , model_name , version = None , list_all = False ):
111+ def get_model (self ,
112+ model_name : str ,
113+ version : Optional [str ] = None ,
114+ list_all : bool = False ) -> List [Dict [str , Any ]]:
93115 req_url = self .address + "/models/" + model_name
94116 if version :
95117 req_url += "/" + version
@@ -102,15 +124,15 @@ def get_model(self, model_name, version=None, list_all=False):
102124 # Doesn't have version
103125 def register_model (
104126 self ,
105- mar_path ,
106- model_name = None ,
107- handler = None ,
108- runtime = None ,
109- batch_size = None ,
110- max_batch_delay = None ,
111- initial_workers = None ,
112- response_timeout = None ,
113- ):
127+ mar_path : str ,
128+ model_name : Optional [ str ] = None ,
129+ handler : Optional [ str ] = None ,
130+ runtime : Optional [ str ] = None ,
131+ batch_size : Optional [ int ] = None ,
132+ max_batch_delay : Optional [ int ] = None ,
133+ initial_workers : Optional [ int ] = None ,
134+ response_timeout : Optional [ int ] = None ,
135+ ) -> Dict [ str , str ] :
114136
115137 req_url = self .address + "/models?url=" + mar_path + "&synchronous=false"
116138 if model_name :
@@ -131,22 +153,32 @@ def register_model(
131153 res = self .client .post (req_url )
132154 return res .json ()
133155
134- def delete_model (self , model_name , version ):
156+ def delete_model (self ,
157+ model_name : str ,
158+ version : Optional [str ] = None ) -> Dict [str , str ]:
135159 req_url = self .address + "/models/" + model_name
136160 if version :
137161 req_url += "/" + version
138162 res = self .client .delete (req_url )
139163 return res .json ()
140164
141- def change_model_default (self , model_name , version ):
165+ def change_model_default (self ,
166+ model_name : str ,
167+ version : Optional [str ] = None ):
142168 req_url = self .address + "/models/" + model_name
143169 if version :
144170 req_url += "/" + version
145171 req_url += "/set-default"
146172 res = self .client .put (req_url )
147173 return res .json ()
148174
149- def change_model_workers (self , model_name , version = None , min_worker = None , max_worker = None , number_gpu = None ):
175+ def change_model_workers (
176+ self ,
177+ model_name : str ,
178+ version : Optional [str ] = None ,
179+ min_worker : Optional [int ] = None ,
180+ max_worker : Optional [int ] = None ,
181+ number_gpu : Optional [int ] = None ) -> Dict [str , str ]:
150182 req_url = self .address + "/models/" + model_name
151183 if version :
152184 req_url += "/" + version
@@ -160,28 +192,30 @@ def change_model_workers(self, model_name, version=None, min_worker=None, max_wo
160192 res = self .client .put (req_url )
161193 return res .json ()
162194
163- def register_workflow (
164- self ,
165- url ,
166- workflow_name = None
167- ):
195+ def register_workflow (self ,
196+ url : str ,
197+ workflow_name : Optional [str ] = None ) -> Dict [str , str ]:
168198 req_url = self .address + "/workflows/" + url
169199 if workflow_name :
170200 req_url += "&workflow_name=" + workflow_name
171201 res = self .client .post (req_url )
172202 return res .json ()
173203
174- def get_workflow (self , workflow_name ) :
204+ def get_workflow (self , workflow_name : str ) -> Dict [ str , str ] :
175205 req_url = self .address + "/workflows/" + workflow_name
176206 res = self .client .get (req_url )
177207 return res .json ()
178208
179- def unregister_workflow (self , workflow_name ) :
209+ def unregister_workflow (self , workflow_name : str ) -> Dict [ str , str ] :
180210 req_url = self .address + "/workflows/" + workflow_name
181211 res = self .client .delete (req_url )
182212 return res .json ()
183213
184- def list_workflows (self , limit = None , next_page_token = None ):
214+ def list_workflows (
215+ self ,
216+ limit : Optional [int ] = None ,
217+ next_page_token : Optional [int ] = None
218+ ) -> Optional [Dict [str , Any ]]:
185219 req_url = self .address + "/workflows/"
186220 if limit :
187221 req_url += "&limit=" + str (limit )
0 commit comments