diff --git a/iotdb-core/ainode/ainode/core/config.py b/iotdb-core/ainode/ainode/core/config.py index b7dbf3fc94c21..b4694cb9c3b18 100644 --- a/iotdb-core/ainode/ainode/core/config.py +++ b/iotdb-core/ainode/ainode/core/config.py @@ -38,6 +38,7 @@ AINODE_ROOT_DIR, AINODE_SYSTEM_DIR, AINODE_SYSTEM_FILE_NAME, + AINODE_TARGET_CONFIG_NODE_LIST, AINODE_THRIFT_COMPRESSION_ENABLED, AINODE_VERSION_INFO, ) @@ -73,7 +74,7 @@ def __init__(self): self._ain_model_storage_cache_size = 30 # Target ConfigNode to be connected by AINode - self._ain_target_config_node_list: TEndPoint = TEndPoint("127.0.0.1", 10710) + self._ain_target_config_node_list: TEndPoint = AINODE_TARGET_CONFIG_NODE_LIST # use for node management self._ainode_id = 0 diff --git a/iotdb-core/ainode/ainode/core/constant.py b/iotdb-core/ainode/ainode/core/constant.py index c0e021ceebf13..c307dbafe6396 100644 --- a/iotdb-core/ainode/ainode/core/constant.py +++ b/iotdb-core/ainode/ainode/core/constant.py @@ -21,6 +21,8 @@ from enum import Enum from typing import List +from ainode.thrift.common.ttypes import TEndPoint + AINODE_CONF_DIRECTORY_NAME = "conf" AINODE_ROOT_CONF_DIRECTORY_NAME = "conf" AINODE_CONF_FILE_NAME = "iotdb-ainode.properties" @@ -49,6 +51,7 @@ AINODE_CLUSTER_INGRESS_USERNAME = "root" AINODE_CLUSTER_INGRESS_PASSWORD = "root" AINODE_CLUSTER_INGRESS_TIME_ZONE = "UTC+8" +AINODE_TARGET_CONFIG_NODE_LIST = TEndPoint("127.0.0.1", 10710) # AINode log AINODE_LOG_FILE_NAMES = [ diff --git a/iotdb-core/ainode/ainode/core/handler.py b/iotdb-core/ainode/ainode/core/handler.py index 804be6379343e..524b80a88d8ab 100644 --- a/iotdb-core/ainode/ainode/core/handler.py +++ b/iotdb-core/ainode/ainode/core/handler.py @@ -30,6 +30,7 @@ TInferenceResp, TRegisterModelReq, TRegisterModelResp, + TShowModelsReq, TShowModelsResp, TTrainingReq, ) @@ -58,8 +59,8 @@ def forecast(self, req: TForecastReq) -> TSStatus: def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: return ClusterManager.get_heart_beat(req) - def showModels(self) -> TShowModelsResp: - return self._model_manager.show_models() + def showModels(self, req: TShowModelsReq) -> TShowModelsResp: + return self._model_manager.show_models(req) def createTrainingTask(self, req: TTrainingReq) -> TSStatus: pass diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py index 9d092c78c7cf9..9eda1c22651ba 100644 --- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py @@ -130,7 +130,7 @@ def _get_strategy(self, model_id, model): return TimerXLStrategy(model) if isinstance(model, SundialForPrediction): return SundialStrategy(model) - if self.model_manager.model_storage._is_built_in(model_id): + if self.model_manager.model_storage._is_built_in_or_fine_tuned(model_id): return BuiltInStrategy(model) return RegisteredStrategy(model) diff --git a/iotdb-core/ainode/ainode/core/manager/model_manager.py b/iotdb-core/ainode/ainode/core/manager/model_manager.py index 4688edf04c328..bb589a281bf63 100644 --- a/iotdb-core/ainode/ainode/core/manager/model_manager.py +++ b/iotdb-core/ainode/ainode/core/manager/model_manager.py @@ -33,6 +33,7 @@ TDeleteModelReq, TRegisterModelReq, TRegisterModelResp, + TShowModelsReq, TShowModelsResp, ) from ainode.thrift.common.ttypes import TSStatus @@ -55,19 +56,16 @@ def register_model(self, req: TRegisterModelReq) -> TRegisterModelResp: ) except InvalidUriError as e: logger.warning(e) - self.model_storage.delete_model(req.modelId) return TRegisterModelResp( get_status(TSStatusCode.INVALID_URI_ERROR, e.message) ) except BadConfigValueError as e: logger.warning(e) - self.model_storage.delete_model(req.modelId) return TRegisterModelResp( get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, e.message) ) except YAMLError as e: logger.warning(e) - self.model_storage.delete_model(req.modelId) if hasattr(e, "problem_mark"): mark = e.problem_mark return TRegisterModelResp( @@ -85,7 +83,6 @@ def register_model(self, req: TRegisterModelReq) -> TRegisterModelResp: ) except Exception as e: logger.warning(e) - self.model_storage.delete_model(req.modelId) return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) def delete_model(self, req: TDeleteModelReq) -> TSStatus: @@ -141,8 +138,8 @@ def get_ckpt_path(self, model_id: str) -> str: """ return self.model_storage.get_ckpt_path(model_id) - def show_models(self) -> TShowModelsResp: - return self.model_storage.show_models() + def show_models(self, req: TShowModelsReq) -> TShowModelsResp: + return self.model_storage.show_models(req) def register_built_in_model(self, model_info: ModelInfo): self.model_storage.register_built_in_model(model_info) diff --git a/iotdb-core/ainode/ainode/core/model/model_storage.py b/iotdb-core/ainode/ainode/core/model/model_storage.py index 15727392e0bee..5a544cc0fea73 100644 --- a/iotdb-core/ainode/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/ainode/core/model/model_storage.py @@ -50,7 +50,7 @@ get_built_in_model_type, ) from ainode.core.util.lock import ModelLockPool -from ainode.thrift.ainode.ttypes import TShowModelsResp +from ainode.thrift.ainode.ttypes import TShowModelsReq, TShowModelsResp from ainode.thrift.common.ttypes import TSStatus logger = Logger() @@ -211,23 +211,30 @@ def register_model(self, model_id: str, uri: str): configs: TConfigs attributes: str """ - storage_path = os.path.join(self._model_dir, f"{model_id}") - # create storage dir if not exist - if not os.path.exists(storage_path): - os.makedirs(storage_path) - model_storage_path = os.path.join(storage_path, DEFAULT_MODEL_FILE_NAME) - config_storage_path = os.path.join(storage_path, DEFAULT_CONFIG_FILE_NAME) - configs, attributes = fetch_model_by_uri( - uri, model_storage_path, config_storage_path - ) - model_info = ModelInfo( - model_id=model_id, - model_type="", - category=ModelCategory.USER_DEFINED, - state=ModelStates.ACTIVE, - ) - self.register_built_in_model(model_info) - return configs, attributes + with self._lock_pool.get_lock(model_id).write_lock(): + storage_path = os.path.join(self._model_dir, f"{model_id}") + # create storage dir if not exist + if not os.path.exists(storage_path): + os.makedirs(storage_path) + model_storage_path = os.path.join(storage_path, DEFAULT_MODEL_FILE_NAME) + config_storage_path = os.path.join(storage_path, DEFAULT_CONFIG_FILE_NAME) + self._model_info_map[model_id] = ModelInfo( + model_id=model_id, + model_type="", + category=ModelCategory.USER_DEFINED, + state=ModelStates.LOADING, + ) + try: + # TODO: The uri should be fetched asynchronously + configs, attributes = fetch_model_by_uri( + uri, model_storage_path, config_storage_path + ) + self._model_info_map[model_id].state = ModelStates.ACTIVE + return configs, attributes + except Exception as e: + logger.error(f"Failed to register model {model_id}: {e}") + self._model_info_map[model_id].state = ModelStates.INACTIVE + raise e def delete_model(self, model_id: str) -> None: """ @@ -241,9 +248,12 @@ def delete_model(self, model_id: str) -> None: if self._is_built_in(model_id): raise BuiltInModelDeletionError(model_id) - # delete the user-defined model - storage_path = os.path.join(self._model_dir, f"{model_id}") + # delete the user-defined or fine-tuned model with self._lock_pool.get_lock(model_id).write_lock(): + storage_path = os.path.join(self._model_dir, f"{model_id}") + if os.path.exists(storage_path): + shutil.rmtree(storage_path) + storage_path = os.path.join(self._builtin_model_dir, f"{model_id}") if os.path.exists(storage_path): shutil.rmtree(storage_path) if model_id in self._model_info_map: @@ -260,6 +270,21 @@ def _is_built_in(self, model_id: str) -> bool: Returns: bool: True if the model is built-in, False otherwise. """ + return ( + model_id in self._model_info_map + and self._model_info_map[model_id].category == ModelCategory.BUILT_IN + ) + + def _is_built_in_or_fine_tuned(self, model_id: str) -> bool: + """ + Check if the model_id corresponds to a built-in or fine-tuned model. + + Args: + model_id (str): The ID of the model. + + Returns: + bool: True if the model is built-in or fine_tuned, False otherwise. + """ return model_id in self._model_info_map and ( self._model_info_map[model_id].category == ModelCategory.BUILT_IN or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED @@ -275,7 +300,7 @@ def load_model( model: The model instance corresponding to specific model_id """ with self._lock_pool.get_lock(model_id).read_lock(): - if self._is_built_in(model_id): + if self._is_built_in_or_fine_tuned(model_id): model_dir = os.path.join(self._builtin_model_dir, f"{model_id}") return fetch_built_in_model( get_built_in_model_type(self._model_info_map[model_id].model_type), @@ -312,7 +337,7 @@ def save_model(self, model_id: str, model: nn.Module): Whether saving succeeded """ with self._lock_pool.get_lock(model_id).write_lock(): - if self._is_built_in(model_id): + if self._is_built_in_or_fine_tuned(model_id): model_dir = os.path.join(self._builtin_model_dir, f"{model_id}") model.save_pretrained(model_dir) else: @@ -343,12 +368,31 @@ def get_ckpt_path(self, model_id: str) -> str: # Only support built-in models for now return os.path.join(self._builtin_model_dir, f"{model_id}") - def show_models(self) -> TShowModelsResp: + def show_models(self, req: TShowModelsReq) -> TShowModelsResp: + resp_status = TSStatus( + code=TSStatusCode.SUCCESS_STATUS.value, + message="Show models successfully", + ) + if req.modelId: + if req.modelId in self._model_info_map: + model_info = self._model_info_map[req.modelId] + return TShowModelsResp( + status=resp_status, + modelIdList=[req.modelId], + modelTypeMap={req.modelId: model_info.model_type}, + categoryMap={req.modelId: model_info.category.value}, + stateMap={req.modelId: model_info.state.value}, + ) + else: + return TShowModelsResp( + status=resp_status, + modelIdList=[], + modelTypeMap={}, + categoryMap={}, + stateMap={}, + ) return TShowModelsResp( - status=TSStatus( - code=TSStatusCode.SUCCESS_STATUS.value, - message="Show models successfully", - ), + status=resp_status, modelIdList=list(self._model_info_map.keys()), modelTypeMap=dict( (model_id, model_info.model_type) diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java index e8ee032d0e55c..ad1a6bed4eda7 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java @@ -19,6 +19,7 @@ package org.apache.iotdb.confignode.manager; +import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.common.rpc.thrift.TSStatus; @@ -105,7 +106,11 @@ public TShowModelResp showModel(final TShowModelReq req) { new TEndPoint(registeredAINode.getInternalAddress(), registeredAINode.getInternalPort()); try (AINodeClient client = AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint)) { - TShowModelsResp resp = client.showModels(); + TShowModelsReq showModelsReq = new TShowModelsReq(); + if (req.isSetModelId()) { + showModelsReq.setModelId(req.getModelId()); + } + TShowModelsResp resp = client.showModels(showModelsReq); TShowModelResp res = new TShowModelResp().setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); res.setModelIdList(resp.getModelIdList()); diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java index 5a8f832540319..23e02ea2e1d8a 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java @@ -19,6 +19,7 @@ package org.apache.iotdb.confignode.procedure.impl.model; +import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; import org.apache.iotdb.common.rpc.thrift.TSStatus; import org.apache.iotdb.commons.client.ainode.AINodeClient; import org.apache.iotdb.commons.client.ainode.AINodeClientManager; @@ -101,33 +102,35 @@ protected Flow executeFromState(ConfigNodeProcedureEnv env, DropModelState state private void dropModelOnAINode(ConfigNodeProcedureEnv env) { LOGGER.info("Start to drop model file [{}] on AI Node", modelName); - List nodeIds = - env.getConfigManager().getModelManager().getModelDistributions(modelName); - for (Integer nodeId : nodeIds) { - try (AINodeClient client = - AINodeClientManager.getInstance() - .borrowClient( - env.getConfigManager() - .getNodeManager() - .getRegisteredAINode(nodeId) - .getLocation() - .getInternalEndPoint())) { - TSStatus status = client.deleteModel(modelName); - if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - LOGGER.warn( - "Failed to drop model [{}] on AINode [{}], status: {}", - modelName, - nodeId, - status.getMessage()); - } - } catch (Exception e) { - LOGGER.warn( - "Failed to drop model [{}] on AINode [{}], status: {}", - modelName, - nodeId, - e.getMessage()); - } - } + List aiNodes = + env.getConfigManager().getNodeManager().getRegisteredAINodes(); + aiNodes.forEach( + aiNode -> { + int nodeId = aiNode.getLocation().getAiNodeId(); + try (AINodeClient client = + AINodeClientManager.getInstance() + .borrowClient( + env.getConfigManager() + .getNodeManager() + .getRegisteredAINode(nodeId) + .getLocation() + .getInternalEndPoint())) { + TSStatus status = client.deleteModel(modelName); + if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + LOGGER.warn( + "Failed to drop model [{}] on AINode [{}], status: {}", + modelName, + nodeId, + status.getMessage()); + } + } catch (Exception e) { + LOGGER.warn( + "Failed to drop model [{}] on AINode [{}], status: {}", + modelName, + nodeId, + e.getMessage()); + } + }); } private void dropModelOnConfigNode(ConfigNodeProcedureEnv env) { diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java index f573ea03b6d2c..e52310d1505a8 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java @@ -28,6 +28,7 @@ import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp; import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq; import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp; +import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq; import org.apache.iotdb.ainode.rpc.thrift.TWindowParams; @@ -159,9 +160,9 @@ public TSStatus deleteModel(String modelId) throws TException { } } - public TShowModelsResp showModels() throws TException { + public TShowModelsResp showModels(TShowModelsReq req) throws TException { try { - return client.showModels(); + return client.showModels(req); } catch (TException e) { logger.warn( "Failed to connect to AINode from ConfigNode when executing {}: {}", diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift index df6cf5daca312..a4ccef7e75263 100644 --- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift +++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift @@ -97,6 +97,10 @@ struct TForecastResp { 2: required binary forecastResult } +struct TShowModelsReq { + 1: optional string modelId +} + struct TShowModelsResp { 1: required common.TSStatus status 2: optional list modelIdList @@ -108,7 +112,7 @@ struct TShowModelsResp { service IAINodeRPCService { // -------------- For Config Node -------------- - TShowModelsResp showModels() + TShowModelsResp showModels(TShowModelsReq req) common.TSStatus deleteModel(TDeleteModelReq req)