|
3 | 3 |
|
4 | 4 | import httpx |
5 | 5 |
|
6 | | -import streamlit as st |
7 | | - |
8 | | -ENVIRON_WHITELIST=["LD_LIBRARY_PATH","LC_CTYPE","LC_ALL","PATH","JAVA_HOME","PYTHONPATH","TS_CONFIG_FILE","LOG_LOCATION","METRICS_LOCATION"] |
9 | | - |
10 | | -def raise_on_not200(response): |
11 | | - if response.status_code != 200: |
12 | | - st.write("There was an error!") |
13 | | - st.write(response) |
14 | | - |
15 | | -client = httpx.Client(timeout=1000, event_hooks={"response": [raise_on_not200]}) |
16 | | - |
17 | | - |
18 | | -def start_torchserve(model_store, config_path, log_location=None, metrics_location=None): |
19 | | - new_env={} |
20 | | - env=os.environ |
21 | | - for x in ENVIRON_WHITELIST: |
22 | | - if x in env: |
23 | | - new_env[x]=env[x] |
24 | | - if log_location: |
25 | | - new_env["LOG_LOCATION"]=log_location |
26 | | - if metrics_location: |
27 | | - new_env["METRICS_LOCATION"]=metrics_location |
28 | | - if not os.path.isdir(metrics_location): |
29 | | - os.makedirs(metrics_location, exist_ok=True) |
30 | | - if not os.path.isdir(log_location): |
31 | | - os.makedirs(log_location, exist_ok=True) |
32 | | - if not os.path.exists(model_store): |
33 | | - return "Can't find model store path" |
34 | | - elif not os.path.exists(config_path): |
35 | | - return "Can't find configuration path" |
36 | | - else: |
37 | | - torchserve_cmd = f"torchserve --start --ncs --model-store {model_store} --ts-config {config_path}" |
38 | | - subprocess.Popen( |
| 6 | +import logging |
| 7 | + |
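| | +# Only these environment variables are passed through to the torchserve subprocess |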
| 8 | +ENVIRON_WHITELIST = ["LD_LIBRARY_PATH", "LC_CTYPE", "LC_ALL", "PATH", "JAVA_HOME", "PYTHONPATH", "TS_CONFIG_FILE", "LOG_LOCATION", "METRICS_LOCATION"] |
| 9 | + |
| 10 | +log = logging.getLogger(__name__) |
| 11 | + |
| 12 | + |
| 13 | +class LocalTS: |
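| | +    """Manage a local TorchServe process: filtered env, start/stop, version check.""" |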
| 14 | + def __init__(self, model_store, config_path, log_location=None, metrics_location=None): |
| 15 | + new_env = {} |
| 16 | + env = os.environ |
| 17 | + for x in ENVIRON_WHITELIST: |
| 18 | + if x in env: |
| 19 | + new_env[x] = env[x] |
| 20 | +        if log_location: |
| 21 | +            new_env["LOG_LOCATION"] = log_location |
| 22 | +            if not os.path.isdir(log_location): |
| 23 | +                os.makedirs(log_location, exist_ok=True) |
| 24 | +        if metrics_location: |
| 25 | +            new_env["METRICS_LOCATION"] = metrics_location |
| 26 | +            if not os.path.isdir(metrics_location): |
| 27 | +                os.makedirs(metrics_location, exist_ok=True) |
| 28 | + |
| 29 | + self.model_store = model_store |
| 30 | + self.config_path = config_path |
| 31 | + self.log_location = log_location |
| 32 | + self.metrics_location = metrics_location |
| 33 | + self.env = new_env |
| 34 | + |
| 35 | + def check_version(self): |
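| | +        # Return (stdout, stderr) from `torchserve --version`, or ("", exception) on failure |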
| 36 | + try: |
| 37 | +            p = subprocess.run(["torchserve", "--version"], check=True, |
| 38 | +                               stdout=subprocess.PIPE, stderr=subprocess.PIPE, |
| 39 | +                               universal_newlines=True) |
| 40 | +            return p.stdout, p.stderr |
| 41 | +        except (subprocess.CalledProcessError, OSError) as e: |
| 42 | +            return "", e |
| 43 | + |
| 44 | + def start_torchserve(self): |
| 45 | + |
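| | +        # Validate the paths, then launch torchserve detached from this process |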
| 46 | + if not os.path.exists(self.model_store): |
| 47 | + return "Can't find model store path" |
| 48 | + elif not os.path.exists(self.config_path): |
| 49 | + return "Can't find configuration path" |
| 50 | +        dashboard_log_path = os.path.join(self.log_location or ".", "torchserve_dashboard.log") |
| 51 | + torchserve_cmd = f"torchserve --start --ncs --model-store {self.model_store} --ts-config {self.config_path}" |
| 52 | + p = subprocess.Popen( |
39 | 53 |             torchserve_cmd.split(" "), |
40 | | - env=new_env, |
41 | | - stdout=open("/dev/null", "r"), |
42 | | - stderr=open("/dev/null", "w"), |
43 | | - preexec_fn=os.setpgrp, |
| 54 | + env=self.env, |
| 55 | + stdout=subprocess.DEVNULL, |
| 56 | +            stderr=open(dashboard_log_path, "a"), |
| 57 | +            start_new_session=True, |
| 58 | +            close_fds=True  # close inherited file descriptors in the child (the default in Python 3) |
44 | 59 |         ) |
45 | | - return "Torchserve is starting..please refresh page" |
| 60 | +        p.communicate()  # wait for the --start launcher to exit; the server keeps running |
| 61 | +        if p.returncode == 0: |
| 62 | +            return f"Torchserve is starting (PID: {p.pid})... please refresh the page" |
| 63 | +        else: |
| 64 | +            return f"Torchserve is already running or failed to start. Check {dashboard_log_path} for errors" |
| 65 | + |
| 66 | + def stop_torchserve(self): |
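| | +        # `torchserve --stop` is synchronous; return its output, or the exception on failure |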
| 67 | + try: |
| 68 | +            p = subprocess.run(["torchserve", "--stop"], check=True, |
| 69 | +                               stdout=subprocess.PIPE, stderr=subprocess.PIPE, |
| 70 | +                               universal_newlines=True) |
| 71 | +            return p.stdout |
| 72 | +        except (subprocess.CalledProcessError, OSError) as e: |
| 73 | +            return e |
| 74 | + |
| 75 | +class ManagementAPI: |
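| | +    """Thin client for the TorchServe management API (default address http://127.0.0.1:8081).""" |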
| 76 | + |
| 77 | + def __init__(self, address, error_callback): |
| 78 | + self.address = address |
| 79 | + self.client = httpx.Client(timeout=1000, event_hooks={"response": [error_callback]}) |
| 80 | + |
| | +    @staticmethod |
| 81 | +    def default_error_callback(response): |
| 82 | +        if response.status_code != 200: |
| 83 | +            log.warning(f"status code: {response.status_code}, {response}") |
| 84 | + |
| 85 | + def get_loaded_models(self): |
| 86 | + try: |
| 87 | + res = self.client.get(self.address + "/models") |
| 88 | + return res.json() |
| 89 | + except httpx.HTTPError: |
| 90 | + return None |
| 91 | + |
| 92 | + def get_model(self, model_name, version=None, list_all=False): |
| 93 | + req_url = self.address + "/models/" + model_name |
| 94 | + if version: |
| 95 | + req_url += "/" + version |
| 96 | + elif list_all: |
| 97 | + req_url += "/all" |
| 98 | + |
| 99 | + res = self.client.get(req_url) |
| 100 | + return res.json() |
| 101 | + |
| 102 | +    # Note: register takes no version parameter; the model version comes from the mar archive |
| 103 | + def register_model( |
| 104 | + self, |
| 105 | + mar_path, |
| 106 | + model_name=None, |
| 107 | + handler=None, |
| 108 | + runtime=None, |
| 109 | + batch_size=None, |
| 110 | + max_batch_delay=None, |
| 111 | + initial_workers=None, |
| 112 | + response_timeout=None, |
| 113 | + ): |
| 114 | + |
| 115 | + req_url = self.address + "/models?url=" + mar_path + "&synchronous=false" |
| 116 | + if model_name: |
| 117 | + req_url += "&model_name=" + model_name |
| 118 | + if handler: |
| 119 | + req_url += "&handler=" + handler |
| 120 | + if runtime: |
| 121 | + req_url += "&runtime=" + runtime |
| 122 | + if batch_size: |
| 123 | + req_url += "&batch_size=" + str(batch_size) |
| 124 | + if max_batch_delay: |
| 125 | + req_url += "&max_batch_delay=" + str(max_batch_delay) |
| 126 | + if initial_workers: |
| 127 | + req_url += "&initial_workers=" + str(initial_workers) |
| 128 | + if response_timeout: |
| 129 | + req_url += "&response_timeout=" + str(response_timeout) |
| 130 | + |
| 131 | + res = self.client.post(req_url) |
| 132 | + return res.json() |
| 133 | + |
| 134 | + def delete_model(self, model_name, version): |
| 135 | + req_url = self.address + "/models/" + model_name |
| 136 | + if version: |
| 137 | + req_url += "/" + version |
| 138 | + res = self.client.delete(req_url) |
| 139 | + return res.json() |
| 140 | + |
| 141 | + def change_model_default(self, model_name, version): |
| 142 | + req_url = self.address + "/models/" + model_name |
| 143 | + if version: |
| 144 | + req_url += "/" + version |
| 145 | + req_url += "/set-default" |
| 146 | + res = self.client.put(req_url) |
| 147 | + return res.json() |
| 148 | + |
| 149 | + def change_model_workers(self, model_name, version=None, min_worker=None, max_worker=None, number_gpu=None): |
| 150 | + req_url = self.address + "/models/" + model_name |
| 151 | + if version: |
| 152 | + req_url += "/" + version |
| 153 | + req_url += "?synchronous=false" |
| 154 | + if min_worker: |
| 155 | + req_url += "&min_worker=" + str(min_worker) |
| 156 | + if max_worker: |
| 157 | + req_url += "&max_worker=" + str(max_worker) |
| 158 | + if number_gpu: |
| 159 | + req_url += "&number_gpu=" + str(number_gpu) |
| 160 | + res = self.client.put(req_url) |
| 161 | + return res.json() |
46 | 162 |
|
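| | +    # Workflow endpoints: register / describe / unregister / list under /workflows |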
| 163 | + def register_workflow( |
| 164 | + self, |
| 165 | + url, |
| 166 | + workflow_name=None |
| 167 | + ): |
| 168 | +        req_url = self.address + "/workflows?url=" + url |
| 169 | +        if workflow_name: |
| 170 | +            req_url += "&workflow_name=" + workflow_name |
| 171 | + res = self.client.post(req_url) |
| 172 | + return res.json() |
47 | 173 |
|
48 | | -def stop_torchserve(): |
49 | | - subprocess.Popen(["torchserve", "--stop"]) |
50 | | - return "Torchserve stopped" |
| 174 | + def get_workflow(self, workflow_name): |
| 175 | + req_url = self.address + "/workflows/" + workflow_name |
| 176 | + res = self.client.get(req_url) |
| 177 | + return res.json() |
51 | 178 |
|
| 179 | + def unregister_workflow(self, workflow_name): |
| 180 | + req_url = self.address + "/workflows/" + workflow_name |
| 181 | + res = self.client.delete(req_url) |
| 182 | + return res.json() |
52 | 183 |
|
53 | | -def get_loaded_models(M_API): |
54 | | - try: |
55 | | - res = client.get(M_API + "/models") |
| 184 | + def list_workflows(self, limit=None, next_page_token=None): |
| 185 | +        req_url = self.address + "/workflows" |
| 186 | +        if limit: |
| 187 | +            req_url += "?limit=" + str(limit) |
| 188 | +        if next_page_token: |
| 189 | +            req_url += ("&" if limit else "?") + "next_page_token=" + str(next_page_token) |
| 190 | + try: |
| 191 | + res = self.client.get(req_url) |
| 192 | + except httpx.HTTPError: |
| 193 | + return None |
56 | 194 |         return res.json() |
57 | | - except httpx.HTTPError as exc: |
58 | | - return None |
59 | | - |
60 | | - |
61 | | -def get_model(M_API, model_name, version=None, list_all=False): |
62 | | - req_url = M_API + "/models/" + model_name |
63 | | - if version: |
64 | | - req_url += "/" + version |
65 | | - elif list_all: |
66 | | - req_url += "/all" |
67 | | - |
68 | | - res = client.get(req_url) |
69 | | - return res.json() |
70 | | - |
71 | | - |
72 | | -def register_model( |
73 | | - M_API, |
74 | | - mar_path, |
75 | | - model_name=None, |
76 | | - version=None, |
77 | | - handler=None, |
78 | | - runtime=None, |
79 | | - batch_size=None, |
80 | | - max_batch_delay=None, |
81 | | - initial_workers=None, |
82 | | - response_timeout=None, |
83 | | -): |
84 | | - |
85 | | - req_url = M_API + "/models?url=" + mar_path + "&synchronous=false" |
86 | | - if model_name: |
87 | | - req_url += "&model_name=" + model_name |
88 | | - if handler: |
89 | | - req_url += "&handler=" + handler |
90 | | - if runtime: |
91 | | - req_url += "&runtime=" + runtime |
92 | | - if batch_size: |
93 | | - req_url += "&batch_size=" + str(batch_size) |
94 | | - if max_batch_delay: |
95 | | - req_url += "&max_batch_delay=" + str(max_batch_delay) |
96 | | - if initial_workers: |
97 | | - req_url += "&initial_workers=" + str(initial_workers) |
98 | | - if response_timeout: |
99 | | - req_url += "&response_timeout=" + str(response_timeout) |
100 | | - |
101 | | - res = client.post(req_url) |
102 | | - return res.json() |
103 | | - |
104 | | - |
105 | | -def delete_model(M_API, model_name, version): |
106 | | - req_url = M_API + "/models/" + model_name |
107 | | - if version: |
108 | | - req_url += "/" + version |
109 | | - res = client.delete(req_url) |
110 | | - return res.json() |
111 | | - |
112 | | - |
113 | | -def change_model_default(M_API, model_name, version): |
114 | | - req_url = M_API + "/models/" + model_name |
115 | | - if version: |
116 | | - req_url += "/" + version |
117 | | - req_url += "/set-default" |
118 | | - res = client.put(req_url) |
119 | | - return res.json() |
120 | | - |
121 | | - |
122 | | -def change_model_workers(M_API, model_name, version=None, min_worker=None, max_worker=None, number_gpu=None): |
123 | | - req_url = M_API + "/models/" + model_name |
124 | | - if version: |
125 | | - req_url += "/" + version |
126 | | - req_url += "?synchronous=false" |
127 | | - if min_worker: |
128 | | - req_url += "&min_worker=" + str(min_worker) |
129 | | - if max_worker: |
130 | | - req_url += "&max_worker=" + str(max_worker) |
131 | | - if number_gpu: |
132 | | - req_url += "&number_gpu=" + str(number_gpu) |
133 | | - res = client.put(req_url) |
134 | | - return res.json() |
|