66from streamlit .script_runner import RerunException
77
88import api as tsa
9+ from pathlib import Path
910
1011st .set_page_config (
1112 page_title = "Torchserve Management Dashboard" ,
1213 page_icon = "./icon.png" ,
1314 layout = "centered" ,
1415 initial_sidebar_state = "expanded" ,
1516)
16-
17+ st . write ( os . environ )
1718parser = argparse .ArgumentParser (description = "Torchserve dashboard" )
1819
1920parser .add_argument ("--model_store" , default = None , help = "Directory where your models are stored" )
2021parser .add_argument ("--config_path" , default = "./default.torchserve.properties" , help = "Torchserve config path" )
22+ parser .add_argument ("--log_location" , default = "./" , help = "Passed as environment variable LOG_LOCATION to Torchserve" )
23+ parser .add_argument ("--metrics_location" , default = "./" , help = "Passed as environment variable METRICS_LOCATION to Torchserve" )
2124try :
2225 args = parser .parse_args ()
2326except SystemExit as e :
2831M_API = "http://127.0.0.1:8081"
2932model_store = args .model_store
3033config_path = args .config_path
34+ log_location = args .log_location
35+ if log_location :
36+ log_location = str (Path (log_location ).resolve ())
37+ metrics_location = args .metrics_location
38+ if metrics_location :
39+ metrics_location = str (Path (metrics_location ).resolve ())
3140config = None
3241default_key = "None"
3342
@@ -56,16 +65,18 @@ def last_res():
5665def get_model_store ():
5766 return os .listdir (model_store )
5867
59-
68+ # As a design choice I'm leaving config_path,log_location,metrics_location non-editable from the UI as a semi-security measure (maybe?:/)
6069##########Sidebar##########
6170st .sidebar .markdown (f"## Help" )
62- st .sidebar .markdown (f"### Management API: \n { M_API } " )
63- st .sidebar .markdown (f"### Model Store Path: \n { model_store } " )
64- st .sidebar .markdown (f"### Config Path: \n { config_path } " )
71+ with st .sidebar .beta_expander (label = "Show Paths:" , expanded = False ):
72+ st .markdown (f"### Model Store Path: \n { model_store } " )
73+ st .markdown (f"### Config Path: \n { config_path } " )
74+ st .markdown (f"### Log Location: \n { log_location } " )
75+ st .markdown (f"### Metrics Location: \n { metrics_location } " )
6576
6677start = st .sidebar .button ("Start Torchserve" )
6778if start :
68- last_res ()[0 ]= tsa .start_torchserve (model_store , config_path )
79+ last_res ()[0 ]= tsa .start_torchserve (model_store , config_path , log_location , metrics_location )
6980 rerun ()
7081
7182stop = st .sidebar .button ("Stop Torchserve" )
@@ -104,7 +115,7 @@ def get_model_store():
104115 p = st .checkbox ("or use another path" )
105116 if p :
106117 mar_path = placeholder .text_input ("Input mar file path*" )
107- model_name = st .text_input (label = "Model name * " )
118+ model_name = st .text_input (label = "Model name (overrides predefined) " )
108119 col1 , col2 = st .beta_columns (2 )
109120 batch_size = col1 .number_input (label = "batch_size" , value = 0 , min_value = 0 , step = 1 )
110121 max_batch_delay = col2 .number_input (label = "max_batch_delay" , value = 0 , min_value = 0 , step = 100 )
@@ -114,21 +125,26 @@ def get_model_store():
114125 runtime = col2 .text_input (label = "runtime" )
115126
116127 proceed = st .button ("Register" )
117- if proceed and model_name and mar_path != default_key :
118- st .write (f"Registering Model...{ mar_path } as { model_name } " )
119- res = tsa .register_model (
120- M_API ,
121- mar_path ,
122- model_name ,
123- handler = handler ,
124- runtime = runtime ,
125- batch_size = batch_size ,
126- max_batch_delay = max_batch_delay ,
127- initial_workers = initial_workers ,
128- response_timeout = response_timeout ,
129- )
130- last_res ()[0 ] = res
131- rerun ()
128+ if proceed :
129+ if mar_path != default_key :
130+ st .write (f"Registering Model...{ mar_path } " )
131+ res = tsa .register_model (
132+ M_API ,
133+ mar_path ,
134+ model_name ,
135+ handler = handler ,
136+ runtime = runtime ,
137+ batch_size = batch_size ,
138+ max_batch_delay = max_batch_delay ,
139+ initial_workers = initial_workers ,
140+ response_timeout = response_timeout ,
141+ )
142+ last_res ()[0 ] = res
143+ rerun ()
144+ else :
145+ st .write (":octagonal_sign: Fill the required fileds!" )
146+
147+
132148
133149 with st .beta_expander (label = "Remove a model" , expanded = False ):
134150
@@ -141,11 +157,14 @@ def get_model_store():
141157 versions = [m ["modelVersion" ] for m in versions ]
142158 version = st .selectbox ("Choose version to remove" , [default_key ] + versions , index = 0 )
143159 proceed = st .button ("Remove" )
144- if proceed and model_name != default_key and version != default_key :
145- res = tsa .delete_model (M_API , model_name , version )
146- last_res ()[0 ] = res
147- rerun ()
148-
160+ if proceed :
161+ if model_name != default_key and version != default_key :
162+ res = tsa .delete_model (M_API , model_name , version )
163+ last_res ()[0 ] = res
164+ rerun ()
165+ else :
166+ st .write (":octagonal_sign: Pick a model & version!" )
167+
149168 with st .beta_expander (label = "Get model details" , expanded = False ):
150169
151170 st .header ("Get model details" )
0 commit comments