Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion iotdb-core/ainode/ainode/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
AINODE_ROOT_DIR,
AINODE_SYSTEM_DIR,
AINODE_SYSTEM_FILE_NAME,
AINODE_TARGET_CONFIG_NODE_LIST,
AINODE_THRIFT_COMPRESSION_ENABLED,
AINODE_VERSION_INFO,
)
Expand Down Expand Up @@ -73,7 +74,7 @@ def __init__(self):
self._ain_model_storage_cache_size = 30

# Target ConfigNode to be connected by AINode
self._ain_target_config_node_list: TEndPoint = TEndPoint("127.0.0.1", 10710)
self._ain_target_config_node_list: TEndPoint = AINODE_TARGET_CONFIG_NODE_LIST

# used for node management
self._ainode_id = 0
Expand Down
3 changes: 3 additions & 0 deletions iotdb-core/ainode/ainode/core/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from enum import Enum
from typing import List

from ainode.thrift.common.ttypes import TEndPoint

AINODE_CONF_DIRECTORY_NAME = "conf"
AINODE_ROOT_CONF_DIRECTORY_NAME = "conf"
AINODE_CONF_FILE_NAME = "iotdb-ainode.properties"
Expand Down Expand Up @@ -49,6 +51,7 @@
AINODE_CLUSTER_INGRESS_USERNAME = "root"
AINODE_CLUSTER_INGRESS_PASSWORD = "root"
AINODE_CLUSTER_INGRESS_TIME_ZONE = "UTC+8"
AINODE_TARGET_CONFIG_NODE_LIST = TEndPoint("127.0.0.1", 10710)

# AINode log
AINODE_LOG_FILE_NAMES = [
Expand Down
5 changes: 3 additions & 2 deletions iotdb-core/ainode/ainode/core/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
TInferenceResp,
TRegisterModelReq,
TRegisterModelResp,
TShowModelsReq,
TShowModelsResp,
TTrainingReq,
)
Expand Down Expand Up @@ -58,8 +59,8 @@ def forecast(self, req: TForecastReq) -> TSStatus:
def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
return ClusterManager.get_heart_beat(req)

def showModels(self) -> TShowModelsResp:
return self._model_manager.show_models()
def showModels(self, req: TShowModelsReq) -> TShowModelsResp:
return self._model_manager.show_models(req)

def createTrainingTask(self, req: TTrainingReq) -> TSStatus:
pass
2 changes: 1 addition & 1 deletion iotdb-core/ainode/ainode/core/manager/inference_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _get_strategy(self, model_id, model):
return TimerXLStrategy(model)
if isinstance(model, SundialForPrediction):
return SundialStrategy(model)
if self.model_manager.model_storage._is_built_in(model_id):
if self.model_manager.model_storage._is_built_in_or_fine_tuned(model_id):
return BuiltInStrategy(model)
return RegisteredStrategy(model)

Expand Down
9 changes: 3 additions & 6 deletions iotdb-core/ainode/ainode/core/manager/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
TDeleteModelReq,
TRegisterModelReq,
TRegisterModelResp,
TShowModelsReq,
TShowModelsResp,
)
from ainode.thrift.common.ttypes import TSStatus
Expand All @@ -55,19 +56,16 @@ def register_model(self, req: TRegisterModelReq) -> TRegisterModelResp:
)
except InvalidUriError as e:
logger.warning(e)
self.model_storage.delete_model(req.modelId)
return TRegisterModelResp(
get_status(TSStatusCode.INVALID_URI_ERROR, e.message)
)
except BadConfigValueError as e:
logger.warning(e)
self.model_storage.delete_model(req.modelId)
return TRegisterModelResp(
get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, e.message)
)
except YAMLError as e:
logger.warning(e)
self.model_storage.delete_model(req.modelId)
if hasattr(e, "problem_mark"):
mark = e.problem_mark
return TRegisterModelResp(
Expand All @@ -85,7 +83,6 @@ def register_model(self, req: TRegisterModelReq) -> TRegisterModelResp:
)
except Exception as e:
logger.warning(e)
self.model_storage.delete_model(req.modelId)
return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR))

def delete_model(self, req: TDeleteModelReq) -> TSStatus:
Expand Down Expand Up @@ -141,8 +138,8 @@ def get_ckpt_path(self, model_id: str) -> str:
"""
return self.model_storage.get_ckpt_path(model_id)

def show_models(self) -> TShowModelsResp:
return self.model_storage.show_models()
def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
return self.model_storage.show_models(req)

def register_built_in_model(self, model_info: ModelInfo):
self.model_storage.register_built_in_model(model_info)
Expand Down
98 changes: 71 additions & 27 deletions iotdb-core/ainode/ainode/core/model/model_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
get_built_in_model_type,
)
from ainode.core.util.lock import ModelLockPool
from ainode.thrift.ainode.ttypes import TShowModelsResp
from ainode.thrift.ainode.ttypes import TShowModelsReq, TShowModelsResp
from ainode.thrift.common.ttypes import TSStatus

logger = Logger()
Expand Down Expand Up @@ -211,23 +211,30 @@ def register_model(self, model_id: str, uri: str):
configs: TConfigs
attributes: str
"""
storage_path = os.path.join(self._model_dir, f"{model_id}")
# create storage dir if not exist
if not os.path.exists(storage_path):
os.makedirs(storage_path)
model_storage_path = os.path.join(storage_path, DEFAULT_MODEL_FILE_NAME)
config_storage_path = os.path.join(storage_path, DEFAULT_CONFIG_FILE_NAME)
configs, attributes = fetch_model_by_uri(
uri, model_storage_path, config_storage_path
)
model_info = ModelInfo(
model_id=model_id,
model_type="",
category=ModelCategory.USER_DEFINED,
state=ModelStates.ACTIVE,
)
self.register_built_in_model(model_info)
return configs, attributes
with self._lock_pool.get_lock(model_id).write_lock():
storage_path = os.path.join(self._model_dir, f"{model_id}")
# create storage dir if not exist
if not os.path.exists(storage_path):
os.makedirs(storage_path)
model_storage_path = os.path.join(storage_path, DEFAULT_MODEL_FILE_NAME)
config_storage_path = os.path.join(storage_path, DEFAULT_CONFIG_FILE_NAME)
self._model_info_map[model_id] = ModelInfo(
model_id=model_id,
model_type="",
category=ModelCategory.USER_DEFINED,
state=ModelStates.LOADING,
)
try:
# TODO: The uri should be fetched asynchronously
configs, attributes = fetch_model_by_uri(
uri, model_storage_path, config_storage_path
)
self._model_info_map[model_id].state = ModelStates.ACTIVE
return configs, attributes
except Exception as e:
logger.error(f"Failed to register model {model_id}: {e}")
self._model_info_map[model_id].state = ModelStates.INACTIVE
raise e

def delete_model(self, model_id: str) -> None:
"""
Expand All @@ -241,9 +248,12 @@ def delete_model(self, model_id: str) -> None:
if self._is_built_in(model_id):
raise BuiltInModelDeletionError(model_id)

# delete the user-defined model
storage_path = os.path.join(self._model_dir, f"{model_id}")
# delete the user-defined or fine-tuned model
with self._lock_pool.get_lock(model_id).write_lock():
storage_path = os.path.join(self._model_dir, f"{model_id}")
if os.path.exists(storage_path):
shutil.rmtree(storage_path)
storage_path = os.path.join(self._builtin_model_dir, f"{model_id}")
if os.path.exists(storage_path):
shutil.rmtree(storage_path)
if model_id in self._model_info_map:
Expand All @@ -260,6 +270,21 @@ def _is_built_in(self, model_id: str) -> bool:
Returns:
bool: True if the model is built-in, False otherwise.
"""
return (
model_id in self._model_info_map
and self._model_info_map[model_id].category == ModelCategory.BUILT_IN
)

def _is_built_in_or_fine_tuned(self, model_id: str) -> bool:
"""
Check if the model_id corresponds to a built-in or fine-tuned model.

Args:
model_id (str): The ID of the model.

Returns:
bool: True if the model is built-in or fine-tuned, False otherwise.
"""
return model_id in self._model_info_map and (
self._model_info_map[model_id].category == ModelCategory.BUILT_IN
or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED
Expand All @@ -275,7 +300,7 @@ def load_model(
model: The model instance corresponding to specific model_id
"""
with self._lock_pool.get_lock(model_id).read_lock():
if self._is_built_in(model_id):
if self._is_built_in_or_fine_tuned(model_id):
model_dir = os.path.join(self._builtin_model_dir, f"{model_id}")
return fetch_built_in_model(
get_built_in_model_type(self._model_info_map[model_id].model_type),
Expand Down Expand Up @@ -312,7 +337,7 @@ def save_model(self, model_id: str, model: nn.Module):
Whether saving succeeded
"""
with self._lock_pool.get_lock(model_id).write_lock():
if self._is_built_in(model_id):
if self._is_built_in_or_fine_tuned(model_id):
model_dir = os.path.join(self._builtin_model_dir, f"{model_id}")
model.save_pretrained(model_dir)
else:
Expand Down Expand Up @@ -343,12 +368,31 @@ def get_ckpt_path(self, model_id: str) -> str:
# Only built-in models are supported for now
return os.path.join(self._builtin_model_dir, f"{model_id}")

def show_models(self) -> TShowModelsResp:
def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
resp_status = TSStatus(
code=TSStatusCode.SUCCESS_STATUS.value,
message="Show models successfully",
)
if req.modelId:
if req.modelId in self._model_info_map:
model_info = self._model_info_map[req.modelId]
return TShowModelsResp(
status=resp_status,
modelIdList=[req.modelId],
modelTypeMap={req.modelId: model_info.model_type},
categoryMap={req.modelId: model_info.category.value},
stateMap={req.modelId: model_info.state.value},
)
else:
return TShowModelsResp(
status=resp_status,
modelIdList=[],
modelTypeMap={},
categoryMap={},
stateMap={},
)
return TShowModelsResp(
status=TSStatus(
code=TSStatusCode.SUCCESS_STATUS.value,
message="Show models successfully",
),
status=resp_status,
modelIdList=list(self._model_info_map.keys()),
modelTypeMap=dict(
(model_id, model_info.model_type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.iotdb.confignode.manager;

import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq;
import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
Expand Down Expand Up @@ -105,7 +106,11 @@ public TShowModelResp showModel(final TShowModelReq req) {
new TEndPoint(registeredAINode.getInternalAddress(), registeredAINode.getInternalPort());
try (AINodeClient client =
AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint)) {
TShowModelsResp resp = client.showModels();
TShowModelsReq showModelsReq = new TShowModelsReq();
if (req.isSetModelId()) {
showModelsReq.setModelId(req.getModelId());
}
TShowModelsResp resp = client.showModels(showModelsReq);
TShowModelResp res =
new TShowModelResp().setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
res.setModelIdList(resp.getModelIdList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.iotdb.confignode.procedure.impl.model;

import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
import org.apache.iotdb.commons.client.ainode.AINodeClient;
import org.apache.iotdb.commons.client.ainode.AINodeClientManager;
Expand Down Expand Up @@ -101,33 +102,35 @@ protected Flow executeFromState(ConfigNodeProcedureEnv env, DropModelState state
private void dropModelOnAINode(ConfigNodeProcedureEnv env) {
LOGGER.info("Start to drop model file [{}] on AI Node", modelName);

List<Integer> nodeIds =
env.getConfigManager().getModelManager().getModelDistributions(modelName);
for (Integer nodeId : nodeIds) {
try (AINodeClient client =
AINodeClientManager.getInstance()
.borrowClient(
env.getConfigManager()
.getNodeManager()
.getRegisteredAINode(nodeId)
.getLocation()
.getInternalEndPoint())) {
TSStatus status = client.deleteModel(modelName);
if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
LOGGER.warn(
"Failed to drop model [{}] on AINode [{}], status: {}",
modelName,
nodeId,
status.getMessage());
}
} catch (Exception e) {
LOGGER.warn(
"Failed to drop model [{}] on AINode [{}], status: {}",
modelName,
nodeId,
e.getMessage());
}
}
List<TAINodeConfiguration> aiNodes =
env.getConfigManager().getNodeManager().getRegisteredAINodes();
aiNodes.forEach(
aiNode -> {
int nodeId = aiNode.getLocation().getAiNodeId();
try (AINodeClient client =
AINodeClientManager.getInstance()
.borrowClient(
env.getConfigManager()
.getNodeManager()
.getRegisteredAINode(nodeId)
.getLocation()
.getInternalEndPoint())) {
TSStatus status = client.deleteModel(modelName);
if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
LOGGER.warn(
"Failed to drop model [{}] on AINode [{}], status: {}",
modelName,
nodeId,
status.getMessage());
}
} catch (Exception e) {
LOGGER.warn(
"Failed to drop model [{}] on AINode [{}], status: {}",
modelName,
nodeId,
e.getMessage());
}
});
}

private void dropModelOnConfigNode(ConfigNodeProcedureEnv env) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp;
import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq;
import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp;
import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq;
import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq;
import org.apache.iotdb.ainode.rpc.thrift.TWindowParams;
Expand Down Expand Up @@ -159,9 +160,9 @@ public TSStatus deleteModel(String modelId) throws TException {
}
}

public TShowModelsResp showModels() throws TException {
public TShowModelsResp showModels(TShowModelsReq req) throws TException {
try {
return client.showModels();
return client.showModels(req);
} catch (TException e) {
logger.warn(
"Failed to connect to AINode from ConfigNode when executing {}: {}",
Expand Down
6 changes: 5 additions & 1 deletion iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ struct TForecastResp {
2: required binary forecastResult
}

// Request argument of IAINodeRPCService.showModels.
struct TShowModelsReq {
// Optional filter. When set, the AINode reports only the model with this id
// (an empty result with a success status if the id is unknown); when unset,
// all known models are listed.
1: optional string modelId
}

struct TShowModelsResp {
1: required common.TSStatus status
2: optional list<string> modelIdList
Expand All @@ -108,7 +112,7 @@ struct TShowModelsResp {
service IAINodeRPCService {

// -------------- For Config Node --------------
TShowModelsResp showModels()
TShowModelsResp showModels(TShowModelsReq req)

common.TSStatus deleteModel(TDeleteModelReq req)

Expand Down
Loading