5050 get_built_in_model_type ,
5151)
5252from ainode .core .util .lock import ModelLockPool
53- from ainode .thrift .ainode .ttypes import TShowModelsResp
53+ from ainode .thrift .ainode .ttypes import TShowModelsReq , TShowModelsResp
5454from ainode .thrift .common .ttypes import TSStatus
5555
5656logger = Logger ()
@@ -211,23 +211,30 @@ def register_model(self, model_id: str, uri: str):
211211 configs: TConfigs
212212 attributes: str
213213 """
214- storage_path = os .path .join (self ._model_dir , f"{ model_id } " )
215- # create storage dir if not exist
216- if not os .path .exists (storage_path ):
217- os .makedirs (storage_path )
218- model_storage_path = os .path .join (storage_path , DEFAULT_MODEL_FILE_NAME )
219- config_storage_path = os .path .join (storage_path , DEFAULT_CONFIG_FILE_NAME )
220- configs , attributes = fetch_model_by_uri (
221- uri , model_storage_path , config_storage_path
222- )
223- model_info = ModelInfo (
224- model_id = model_id ,
225- model_type = "" ,
226- category = ModelCategory .USER_DEFINED ,
227- state = ModelStates .ACTIVE ,
228- )
229- self .register_built_in_model (model_info )
230- return configs , attributes
214+ with self ._lock_pool .get_lock (model_id ).write_lock ():
215+ storage_path = os .path .join (self ._model_dir , f"{ model_id } " )
216+ # create storage dir if not exist
217+ if not os .path .exists (storage_path ):
218+ os .makedirs (storage_path )
219+ model_storage_path = os .path .join (storage_path , DEFAULT_MODEL_FILE_NAME )
220+ config_storage_path = os .path .join (storage_path , DEFAULT_CONFIG_FILE_NAME )
221+ self ._model_info_map [model_id ] = ModelInfo (
222+ model_id = model_id ,
223+ model_type = "" ,
224+ category = ModelCategory .USER_DEFINED ,
225+ state = ModelStates .LOADING ,
226+ )
227+ try :
228+ # TODO: The uri should be fetched asynchronously
229+ configs , attributes = fetch_model_by_uri (
230+ uri , model_storage_path , config_storage_path
231+ )
232+ self ._model_info_map [model_id ].state = ModelStates .ACTIVE
233+ return configs , attributes
234+ except Exception as e :
235+ logger .error (f"Failed to register model { model_id } : { e } " )
236+ self ._model_info_map [model_id ].state = ModelStates .INACTIVE
237+ raise e
231238
232239 def delete_model (self , model_id : str ) -> None :
233240 """
@@ -241,9 +248,12 @@ def delete_model(self, model_id: str) -> None:
241248 if self ._is_built_in (model_id ):
242249 raise BuiltInModelDeletionError (model_id )
243250
244- # delete the user-defined model
245- storage_path = os .path .join (self ._model_dir , f"{ model_id } " )
251+ # delete the user-defined or fine-tuned model
246252 with self ._lock_pool .get_lock (model_id ).write_lock ():
253+ storage_path = os .path .join (self ._model_dir , f"{ model_id } " )
254+ if os .path .exists (storage_path ):
255+ shutil .rmtree (storage_path )
256+ storage_path = os .path .join (self ._builtin_model_dir , f"{ model_id } " )
247257 if os .path .exists (storage_path ):
248258 shutil .rmtree (storage_path )
249259 if model_id in self ._model_info_map :
@@ -260,6 +270,21 @@ def _is_built_in(self, model_id: str) -> bool:
260270 Returns:
261271 bool: True if the model is built-in, False otherwise.
262272 """
273+ return (
274+ model_id in self ._model_info_map
275+ and self ._model_info_map [model_id ].category == ModelCategory .BUILT_IN
276+ )
277+
278+ def _is_built_in_or_fine_tuned (self , model_id : str ) -> bool :
279+ """
280+ Check if the model_id corresponds to a built-in or fine-tuned model.
281+
282+ Args:
283+ model_id (str): The ID of the model.
284+
285+ Returns:
286+ bool: True if the model is built-in or fine_tuned, False otherwise.
287+ """
263288 return model_id in self ._model_info_map and (
264289 self ._model_info_map [model_id ].category == ModelCategory .BUILT_IN
265290 or self ._model_info_map [model_id ].category == ModelCategory .FINE_TUNED
@@ -275,7 +300,7 @@ def load_model(
275300 model: The model instance corresponding to specific model_id
276301 """
277302 with self ._lock_pool .get_lock (model_id ).read_lock ():
278- if self ._is_built_in (model_id ):
303+ if self ._is_built_in_or_fine_tuned (model_id ):
279304 model_dir = os .path .join (self ._builtin_model_dir , f"{ model_id } " )
280305 return fetch_built_in_model (
281306 get_built_in_model_type (self ._model_info_map [model_id ].model_type ),
@@ -312,7 +337,7 @@ def save_model(self, model_id: str, model: nn.Module):
312337 Whether saving succeeded
313338 """
314339 with self ._lock_pool .get_lock (model_id ).write_lock ():
315- if self ._is_built_in (model_id ):
340+ if self ._is_built_in_or_fine_tuned (model_id ):
316341 model_dir = os .path .join (self ._builtin_model_dir , f"{ model_id } " )
317342 model .save_pretrained (model_dir )
318343 else :
@@ -343,12 +368,31 @@ def get_ckpt_path(self, model_id: str) -> str:
343368 # Only support built-in models for now
344369 return os .path .join (self ._builtin_model_dir , f"{ model_id } " )
345370
346- def show_models (self ) -> TShowModelsResp :
371+ def show_models (self , req : TShowModelsReq ) -> TShowModelsResp :
372+ resp_status = TSStatus (
373+ code = TSStatusCode .SUCCESS_STATUS .value ,
374+ message = "Show models successfully" ,
375+ )
376+ if req .modelId :
377+ if req .modelId in self ._model_info_map :
378+ model_info = self ._model_info_map [req .modelId ]
379+ return TShowModelsResp (
380+ status = resp_status ,
381+ modelIdList = [req .modelId ],
382+ modelTypeMap = {req .modelId : model_info .model_type },
383+ categoryMap = {req .modelId : model_info .category .value },
384+ stateMap = {req .modelId : model_info .state .value },
385+ )
386+ else :
387+ return TShowModelsResp (
388+ status = resp_status ,
389+ modelIdList = [],
390+ modelTypeMap = {},
391+ categoryMap = {},
392+ stateMap = {},
393+ )
347394 return TShowModelsResp (
348- status = TSStatus (
349- code = TSStatusCode .SUCCESS_STATUS .value ,
350- message = "Show models successfully" ,
351- ),
395+ status = resp_status ,
352396 modelIdList = list (self ._model_info_map .keys ()),
353397 modelTypeMap = dict (
354398 (model_id , model_info .model_type )
0 commit comments