@@ -144,7 +144,6 @@ def get_model_store():
144144 if proceed and model_name != default_key and version != default_key :
145145 res = tsa .delete_model (M_API , model_name , version )
146146 last_res ()[0 ] = res
147- proceed = False
148147 rerun ()
149148
150149 with st .beta_expander (label = "Get model details" , expanded = False ):
@@ -164,3 +163,34 @@ def get_model_store():
164163 elif version != default_key :
165164 res = tsa .get_model (M_API , model_name , version )
166165 st .write (res )
166+
167+ with st .beta_expander (label = "Scale workers" , expanded = False ):
168+ st .markdown ("# Scale workers [(docs)](https://pytorch.org/serve/management_api.html#scale-workers)" )
169+ model_name = st .selectbox ("Pick model" , [default_key ] + loaded_models_names , index = 0 )
170+ if model_name != default_key :
171+ default_version = tsa .get_model (M_API ,model_name )[0 ]["modelVersion" ]
172+ st .write (f"default version { default_version } " )
173+ versions = tsa .get_model (M_API ,model_name , list_all = False )
174+ versions = [m ["modelVersion" ] for m in versions ]
175+ version = st .selectbox ("Choose version" , ["All" ] + versions , index = 0 )
176+
177+ col1 , col2 , col3 = st .beta_columns (3 )
178+ min_worker = col1 .number_input (label = "min_worker(optional)" , value = - 1 , min_value = - 1 , step = 1 )
179+ max_worker = col2 .number_input (label = "max_worker(optional)" , value = - 1 , min_value = - 1 , step = 1 )
180+ number_gpu = col3 .number_input (label = "number_gpu(optional)" , value = - 1 , min_value = - 1 , step = 1 )
181+ proceed = st .button ("Apply" )
182+ if proceed and model_name != default_key :
183+ # number_input can't be set to None
184+ if version == "All" :
185+ version = None
186+ if min_worker == - 1 :
187+ min_worker = None
188+ if max_worker == - 1 :
189+ max_worker = None
190+ if number_gpu == - 1 :
191+ number_gpu = None
192+
193+ res = tsa .change_model_workers (M_API , model_name , version = version , min_worker = min_worker , max_worker = max_worker , number_gpu = number_gpu )
194+ last_res ()[0 ] = res
195+ rerun ()
196+
0 commit comments