2323from collections .abc import Callable
2424from typing import Dict
2525
26+ import torch
2627from torch import nn
2728
2829from ainode .core .config import AINodeDescriptor
3233 MODEL_CONFIG_FILE_IN_JSON ,
3334 TSStatusCode ,
3435)
36+ from ainode .core .exception import ModelNotExistError
3537from ainode .core .log import Logger
3638from ainode .core .model .built_in_model_factory import (
3739 download_ltsm_if_necessary ,
@@ -104,7 +106,10 @@ def _init_model_info_map(self):
104106 future .add_done_callback (
105107 lambda f , mid = model_id : self ._callback_model_download_result (f , mid )
106108 )
107- # TODO: retrieve user-defined models
109+ # 4. retrieve user-defined models from the model directory
110+ user_defined_models = self ._retrieve_user_defined_models ()
111+ for model_id in user_defined_models :
112+ self ._model_info_map [model_id ] = user_defined_models [model_id ]
108113
109114 def _retrieve_fine_tuned_models (self ):
110115 """
@@ -174,6 +179,28 @@ def _callback_model_download_result(self, future, model_id: str):
174179 else :
175180 self ._model_info_map [model_id ].state = ModelStates .INACTIVE
176181
182+ def _retrieve_user_defined_models (self ):
183+ """
184+ Retrieve user_defined models from the model directory.
185+
186+ Returns:
187+ {"model_id": ModelInfo}
188+ """
189+ result = {}
190+ user_dirs = [
191+ d
192+ for d in os .listdir (self ._model_dir )
193+ if os .path .isdir (os .path .join (self ._model_dir , d )) and d != "weights"
194+ ]
195+ for model_id in user_dirs :
196+ result [model_id ] = ModelInfo (
197+ model_id = model_id ,
198+ model_type = "" ,
199+ category = ModelCategory .USER_DEFINED ,
200+ state = ModelStates .ACTIVE ,
201+ )
202+ return result
203+
177204 def register_model (self , model_id : str , uri : str ):
178205 """
179206 Args:
@@ -190,7 +217,16 @@ def register_model(self, model_id: str, uri: str):
190217 os .makedirs (storage_path )
191218 model_storage_path = os .path .join (storage_path , DEFAULT_MODEL_FILE_NAME )
192219 config_storage_path = os .path .join (storage_path , DEFAULT_CONFIG_FILE_NAME )
193- return fetch_model_by_uri (uri , model_storage_path , config_storage_path )
220+ configs , attributes = fetch_model_by_uri (
221+ uri , model_storage_path , config_storage_path
222+ )
223+ self ._model_info_map [model_id ] = ModelInfo (
224+ model_id = model_id ,
225+ model_type = "" ,
226+ category = ModelCategory .USER_DEFINED ,
227+ state = ModelStates .ACTIVE ,
228+ )
229+ return configs , attributes
194230
195231 def delete_model (self , model_id : str ) -> None :
196232 """
@@ -241,9 +277,26 @@ def load_model(self, model_id: str, acceleration: bool) -> Callable:
241277 model_dir ,
242278 )
243279 else :
244- # TODO: support load the user-defined model
245- # model_dir = os.path.join(self._model_dir, f"{model_id}")
246- raise NotImplementedError
280+ # load the user-defined model
281+ model_dir = os .path .join (self ._model_dir , f"{ model_id } " )
282+ model_path = os .path .join (model_dir , DEFAULT_MODEL_FILE_NAME )
283+
284+ if not os .path .exists (model_path ):
285+ raise ModelNotExistError (model_path )
286+ model = torch .jit .load (model_path )
287+ if (
288+ isinstance (model , torch ._dynamo .eval_frame .OptimizedModule )
289+ or not acceleration
290+ ):
291+ return model
292+
293+ try :
294+ model = torch .compile (model )
295+ except Exception as e :
296+ logger .warning (
297+ f"acceleration failed, fallback to normal mode: { str (e )} "
298+ )
299+ return model
247300
248301 def save_model (self , model_id : str , model : nn .Module ):
249302 """
@@ -257,9 +310,19 @@ def save_model(self, model_id: str, model: nn.Module):
257310 model_dir = os .path .join (self ._builtin_model_dir , f"{ model_id } " )
258311 model .save_pretrained (model_dir )
259312 else :
260- # TODO: support save the user-defined model
261- # model_dir = os.path.join(self._model_dir, f"{model_id}")
262- raise NotImplementedError
313+ # save the user-defined model
314+ model_dir = os .path .join (self ._model_dir , f"{ model_id } " )
315+ os .makedirs (model_dir , exist_ok = True )
316+ model_path = os .path .join (model_dir , DEFAULT_MODEL_FILE_NAME )
317+ try :
318+ scripted_model = (
319+ model
320+ if isinstance (model , torch .jit .ScriptModule )
321+ else torch .jit .script (model )
322+ )
323+ torch .jit .save (scripted_model , model_path )
324+ except Exception as e :
325+ logger .error (f"Failed to save scripted model: { e } " )
263326
264327 def get_ckpt_path (self , model_id : str ) -> str :
265328 """
0 commit comments