Skip to content

Commit be7a1de

Browse files
authored
[AINode] Fix bug that AINode cannot register and invoke user_defined model (#15849)
1 parent 71df495 commit be7a1de

File tree

1 file changed

+71
-8
lines changed

1 file changed

+71
-8
lines changed

iotdb-core/ainode/ainode/core/model/model_storage.py

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from collections.abc import Callable
2424
from typing import Dict
2525

26+
import torch
2627
from torch import nn
2728

2829
from ainode.core.config import AINodeDescriptor
@@ -32,6 +33,7 @@
3233
MODEL_CONFIG_FILE_IN_JSON,
3334
TSStatusCode,
3435
)
36+
from ainode.core.exception import ModelNotExistError
3537
from ainode.core.log import Logger
3638
from ainode.core.model.built_in_model_factory import (
3739
download_ltsm_if_necessary,
@@ -104,7 +106,10 @@ def _init_model_info_map(self):
104106
future.add_done_callback(
105107
lambda f, mid=model_id: self._callback_model_download_result(f, mid)
106108
)
107-
# TODO: retrieve user-defined models
109+
# 4. retrieve user-defined models from the model directory
110+
user_defined_models = self._retrieve_user_defined_models()
111+
for model_id in user_defined_models:
112+
self._model_info_map[model_id] = user_defined_models[model_id]
108113

109114
def _retrieve_fine_tuned_models(self):
110115
"""
@@ -174,6 +179,28 @@ def _callback_model_download_result(self, future, model_id: str):
174179
else:
175180
self._model_info_map[model_id].state = ModelStates.INACTIVE
176181

182+
def _retrieve_user_defined_models(self):
183+
"""
184+
Retrieve user_defined models from the model directory.
185+
186+
Returns:
187+
{"model_id": ModelInfo}
188+
"""
189+
result = {}
190+
user_dirs = [
191+
d
192+
for d in os.listdir(self._model_dir)
193+
if os.path.isdir(os.path.join(self._model_dir, d)) and d != "weights"
194+
]
195+
for model_id in user_dirs:
196+
result[model_id] = ModelInfo(
197+
model_id=model_id,
198+
model_type="",
199+
category=ModelCategory.USER_DEFINED,
200+
state=ModelStates.ACTIVE,
201+
)
202+
return result
203+
177204
def register_model(self, model_id: str, uri: str):
178205
"""
179206
Args:
@@ -190,7 +217,16 @@ def register_model(self, model_id: str, uri: str):
190217
os.makedirs(storage_path)
191218
model_storage_path = os.path.join(storage_path, DEFAULT_MODEL_FILE_NAME)
192219
config_storage_path = os.path.join(storage_path, DEFAULT_CONFIG_FILE_NAME)
193-
return fetch_model_by_uri(uri, model_storage_path, config_storage_path)
220+
configs, attributes = fetch_model_by_uri(
221+
uri, model_storage_path, config_storage_path
222+
)
223+
self._model_info_map[model_id] = ModelInfo(
224+
model_id=model_id,
225+
model_type="",
226+
category=ModelCategory.USER_DEFINED,
227+
state=ModelStates.ACTIVE,
228+
)
229+
return configs, attributes
194230

195231
def delete_model(self, model_id: str) -> None:
196232
"""
@@ -241,9 +277,26 @@ def load_model(self, model_id: str, acceleration: bool) -> Callable:
241277
model_dir,
242278
)
243279
else:
244-
# TODO: support load the user-defined model
245-
# model_dir = os.path.join(self._model_dir, f"{model_id}")
246-
raise NotImplementedError
280+
# load the user-defined model
281+
model_dir = os.path.join(self._model_dir, f"{model_id}")
282+
model_path = os.path.join(model_dir, DEFAULT_MODEL_FILE_NAME)
283+
284+
if not os.path.exists(model_path):
285+
raise ModelNotExistError(model_path)
286+
model = torch.jit.load(model_path)
287+
if (
288+
isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
289+
or not acceleration
290+
):
291+
return model
292+
293+
try:
294+
model = torch.compile(model)
295+
except Exception as e:
296+
logger.warning(
297+
f"acceleration failed, fallback to normal mode: {str(e)}"
298+
)
299+
return model
247300

248301
def save_model(self, model_id: str, model: nn.Module):
249302
"""
@@ -257,9 +310,19 @@ def save_model(self, model_id: str, model: nn.Module):
257310
model_dir = os.path.join(self._builtin_model_dir, f"{model_id}")
258311
model.save_pretrained(model_dir)
259312
else:
260-
# TODO: support save the user-defined model
261-
# model_dir = os.path.join(self._model_dir, f"{model_id}")
262-
raise NotImplementedError
313+
# save the user-defined model
314+
model_dir = os.path.join(self._model_dir, f"{model_id}")
315+
os.makedirs(model_dir, exist_ok=True)
316+
model_path = os.path.join(model_dir, DEFAULT_MODEL_FILE_NAME)
317+
try:
318+
scripted_model = (
319+
model
320+
if isinstance(model, torch.jit.ScriptModule)
321+
else torch.jit.script(model)
322+
)
323+
torch.jit.save(scripted_model, model_path)
324+
except Exception as e:
325+
logger.error(f"Failed to save scripted model: {e}")
263326

264327
def get_ckpt_path(self, model_id: str) -> str:
265328
"""

0 commit comments

Comments
 (0)