diff --git a/torchserve_dashboard/api.py b/torchserve_dashboard/api.py index 1199786..d7170c9 100644 --- a/torchserve_dashboard/api.py +++ b/torchserve_dashboard/api.py @@ -1,15 +1,42 @@ +import logging import os import subprocess +from typing import Any, Callable, Optional import httpx - -import logging +import streamlit as st +from httpx import Response ENVIRON_WHITELIST = ["LD_LIBRARY_PATH", "LC_CTYPE", "LC_ALL", "PATH", "JAVA_HOME", "PYTHONPATH", "TS_CONFIG_FILE", "LOG_LOCATION", "METRICS_LOCATION"] log = logging.getLogger(__name__) + +class HTTPClient(httpx.Client): + + def __init__( + self, + timeout: int = 1000, + error_callback: Optional[Callable] = None + ) -> None: + error_callback = error_callback or self.default_error_callback + super().__init__( + timeout=timeout, event_hooks={'response': [error_callback]} + ) + + @staticmethod + def default_error_callback( + response: Response, *args: Any, **kwargs: Any + ) -> None: + if response.status_code != 200: + log.info( + f"Warn - status code: {response.status_code}, {response}" + ) + st.write(f'There was an error! Status Code {response.status_code}') + st.write(response) + + class LocalTS: def __init__(self, model_store, config_path, log_location=None, metrics_location=None): new_env = {} @@ -31,7 +58,7 @@ def __init__(self, model_store, config_path, log_location=None, metrics_location self.log_location = log_location self.metrics_location = metrics_location self.env = new_env - + def check_version(self): try: p=subprocess.run(["torchserve","--version"], check=True, @@ -40,7 +67,7 @@ def check_version(self): return p.stdout ,p.stderr except (subprocess.CalledProcessError,OSError) as e: return "",e - + def start_torchserve(self): if not os.path.exists(self.model_store): @@ -74,13 +101,9 @@ def stop_torchserve(self): class ManagementAPI: - def __init__(self, address, error_callback): + def __init__(self, address, http_client): self.address = address - self.client = httpx.Client(timeout=1000, event_hooks={"response": [error_callback]}) - - def default_error_callback(response): - if response.status_code != 200: - log.info(f"Warn - status code: {response.status_code},{response}") + self.client = http_client def get_loaded_models(self): try: diff --git a/torchserve_dashboard/dash.py b/torchserve_dashboard/dash.py index be9a182..cc575ed 100644 --- a/torchserve_dashboard/dash.py +++ b/torchserve_dashboard/dash.py @@ -3,7 +3,7 @@ import streamlit as st -from api import ManagementAPI, LocalTS +from torchserve_dashboard.api import ManagementAPI, LocalTS, HTTPClient from pathlib import Path st.set_page_config( @@ -99,8 +99,8 @@ def last_res(): def get_model_store(): return os.listdir(model_store) - -api = ManagementAPI(api_address, error_callback) +http_client = HTTPClient() +api = ManagementAPI(api_address, http_client) ts = LocalTS(model_store, config_path, log_location, metrics_location) ts_version,ts_error=ts.check_version() # doing it this way rather than ts.__version__ on purpose if ts_error: