Skip to content

Commit 0a0ec3f

Browse files
authored
[AINode] Fix model management bugs (#15890)
1 parent a10e213 commit 0a0ec3f

File tree

10 files changed

+127
-68
lines changed

10 files changed

+127
-68
lines changed

iotdb-core/ainode/ainode/core/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
AINODE_ROOT_DIR,
3939
AINODE_SYSTEM_DIR,
4040
AINODE_SYSTEM_FILE_NAME,
41+
AINODE_TARGET_CONFIG_NODE_LIST,
4142
AINODE_THRIFT_COMPRESSION_ENABLED,
4243
AINODE_VERSION_INFO,
4344
)
@@ -73,7 +74,7 @@ def __init__(self):
7374
self._ain_model_storage_cache_size = 30
7475

7576
# Target ConfigNode to be connected by AINode
76-
self._ain_target_config_node_list: TEndPoint = TEndPoint("127.0.0.1", 10710)
77+
self._ain_target_config_node_list: TEndPoint = AINODE_TARGET_CONFIG_NODE_LIST
7778

7879
# use for node management
7980
self._ainode_id = 0

iotdb-core/ainode/ainode/core/constant.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from enum import Enum
2222
from typing import List
2323

24+
from ainode.thrift.common.ttypes import TEndPoint
25+
2426
AINODE_CONF_DIRECTORY_NAME = "conf"
2527
AINODE_ROOT_CONF_DIRECTORY_NAME = "conf"
2628
AINODE_CONF_FILE_NAME = "iotdb-ainode.properties"
@@ -49,6 +51,7 @@
4951
AINODE_CLUSTER_INGRESS_USERNAME = "root"
5052
AINODE_CLUSTER_INGRESS_PASSWORD = "root"
5153
AINODE_CLUSTER_INGRESS_TIME_ZONE = "UTC+8"
54+
AINODE_TARGET_CONFIG_NODE_LIST = TEndPoint("127.0.0.1", 10710)
5255

5356
# AINode log
5457
AINODE_LOG_FILE_NAMES = [

iotdb-core/ainode/ainode/core/handler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
TInferenceResp,
3131
TRegisterModelReq,
3232
TRegisterModelResp,
33+
TShowModelsReq,
3334
TShowModelsResp,
3435
TTrainingReq,
3536
)
@@ -58,8 +59,8 @@ def forecast(self, req: TForecastReq) -> TSStatus:
5859
def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
5960
return ClusterManager.get_heart_beat(req)
6061

61-
def showModels(self) -> TShowModelsResp:
62-
return self._model_manager.show_models()
62+
def showModels(self, req: TShowModelsReq) -> TShowModelsResp:
63+
return self._model_manager.show_models(req)
6364

6465
def createTrainingTask(self, req: TTrainingReq) -> TSStatus:
6566
pass

iotdb-core/ainode/ainode/core/manager/inference_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def _get_strategy(self, model_id, model):
130130
return TimerXLStrategy(model)
131131
if isinstance(model, SundialForPrediction):
132132
return SundialStrategy(model)
133-
if self.model_manager.model_storage._is_built_in(model_id):
133+
if self.model_manager.model_storage._is_built_in_or_fine_tuned(model_id):
134134
return BuiltInStrategy(model)
135135
return RegisteredStrategy(model)
136136

iotdb-core/ainode/ainode/core/manager/model_manager.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
TDeleteModelReq,
3434
TRegisterModelReq,
3535
TRegisterModelResp,
36+
TShowModelsReq,
3637
TShowModelsResp,
3738
)
3839
from ainode.thrift.common.ttypes import TSStatus
@@ -55,19 +56,16 @@ def register_model(self, req: TRegisterModelReq) -> TRegisterModelResp:
5556
)
5657
except InvalidUriError as e:
5758
logger.warning(e)
58-
self.model_storage.delete_model(req.modelId)
5959
return TRegisterModelResp(
6060
get_status(TSStatusCode.INVALID_URI_ERROR, e.message)
6161
)
6262
except BadConfigValueError as e:
6363
logger.warning(e)
64-
self.model_storage.delete_model(req.modelId)
6564
return TRegisterModelResp(
6665
get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, e.message)
6766
)
6867
except YAMLError as e:
6968
logger.warning(e)
70-
self.model_storage.delete_model(req.modelId)
7169
if hasattr(e, "problem_mark"):
7270
mark = e.problem_mark
7371
return TRegisterModelResp(
@@ -85,7 +83,6 @@ def register_model(self, req: TRegisterModelReq) -> TRegisterModelResp:
8583
)
8684
except Exception as e:
8785
logger.warning(e)
88-
self.model_storage.delete_model(req.modelId)
8986
return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR))
9087

9188
def delete_model(self, req: TDeleteModelReq) -> TSStatus:
@@ -141,8 +138,8 @@ def get_ckpt_path(self, model_id: str) -> str:
141138
"""
142139
return self.model_storage.get_ckpt_path(model_id)
143140

144-
def show_models(self) -> TShowModelsResp:
145-
return self.model_storage.show_models()
141+
def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
142+
return self.model_storage.show_models(req)
146143

147144
def register_built_in_model(self, model_info: ModelInfo):
148145
self.model_storage.register_built_in_model(model_info)

iotdb-core/ainode/ainode/core/model/model_storage.py

Lines changed: 71 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
get_built_in_model_type,
5151
)
5252
from ainode.core.util.lock import ModelLockPool
53-
from ainode.thrift.ainode.ttypes import TShowModelsResp
53+
from ainode.thrift.ainode.ttypes import TShowModelsReq, TShowModelsResp
5454
from ainode.thrift.common.ttypes import TSStatus
5555

5656
logger = Logger()
@@ -211,23 +211,30 @@ def register_model(self, model_id: str, uri: str):
211211
configs: TConfigs
212212
attributes: str
213213
"""
214-
storage_path = os.path.join(self._model_dir, f"{model_id}")
215-
# create storage dir if not exist
216-
if not os.path.exists(storage_path):
217-
os.makedirs(storage_path)
218-
model_storage_path = os.path.join(storage_path, DEFAULT_MODEL_FILE_NAME)
219-
config_storage_path = os.path.join(storage_path, DEFAULT_CONFIG_FILE_NAME)
220-
configs, attributes = fetch_model_by_uri(
221-
uri, model_storage_path, config_storage_path
222-
)
223-
model_info = ModelInfo(
224-
model_id=model_id,
225-
model_type="",
226-
category=ModelCategory.USER_DEFINED,
227-
state=ModelStates.ACTIVE,
228-
)
229-
self.register_built_in_model(model_info)
230-
return configs, attributes
214+
with self._lock_pool.get_lock(model_id).write_lock():
215+
storage_path = os.path.join(self._model_dir, f"{model_id}")
216+
# create storage dir if not exist
217+
if not os.path.exists(storage_path):
218+
os.makedirs(storage_path)
219+
model_storage_path = os.path.join(storage_path, DEFAULT_MODEL_FILE_NAME)
220+
config_storage_path = os.path.join(storage_path, DEFAULT_CONFIG_FILE_NAME)
221+
self._model_info_map[model_id] = ModelInfo(
222+
model_id=model_id,
223+
model_type="",
224+
category=ModelCategory.USER_DEFINED,
225+
state=ModelStates.LOADING,
226+
)
227+
try:
228+
# TODO: The uri should be fetched asynchronously
229+
configs, attributes = fetch_model_by_uri(
230+
uri, model_storage_path, config_storage_path
231+
)
232+
self._model_info_map[model_id].state = ModelStates.ACTIVE
233+
return configs, attributes
234+
except Exception as e:
235+
logger.error(f"Failed to register model {model_id}: {e}")
236+
self._model_info_map[model_id].state = ModelStates.INACTIVE
237+
raise e
231238

232239
def delete_model(self, model_id: str) -> None:
233240
"""
@@ -241,9 +248,12 @@ def delete_model(self, model_id: str) -> None:
241248
if self._is_built_in(model_id):
242249
raise BuiltInModelDeletionError(model_id)
243250

244-
# delete the user-defined model
245-
storage_path = os.path.join(self._model_dir, f"{model_id}")
251+
# delete the user-defined or fine-tuned model
246252
with self._lock_pool.get_lock(model_id).write_lock():
253+
storage_path = os.path.join(self._model_dir, f"{model_id}")
254+
if os.path.exists(storage_path):
255+
shutil.rmtree(storage_path)
256+
storage_path = os.path.join(self._builtin_model_dir, f"{model_id}")
247257
if os.path.exists(storage_path):
248258
shutil.rmtree(storage_path)
249259
if model_id in self._model_info_map:
@@ -260,6 +270,21 @@ def _is_built_in(self, model_id: str) -> bool:
260270
Returns:
261271
bool: True if the model is built-in, False otherwise.
262272
"""
273+
return (
274+
model_id in self._model_info_map
275+
and self._model_info_map[model_id].category == ModelCategory.BUILT_IN
276+
)
277+
278+
def _is_built_in_or_fine_tuned(self, model_id: str) -> bool:
279+
"""
280+
Check if the model_id corresponds to a built-in or fine-tuned model.
281+
282+
Args:
283+
model_id (str): The ID of the model.
284+
285+
Returns:
286+
bool: True if the model is built-in or fine_tuned, False otherwise.
287+
"""
263288
return model_id in self._model_info_map and (
264289
self._model_info_map[model_id].category == ModelCategory.BUILT_IN
265290
or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED
@@ -275,7 +300,7 @@ def load_model(
275300
model: The model instance corresponding to specific model_id
276301
"""
277302
with self._lock_pool.get_lock(model_id).read_lock():
278-
if self._is_built_in(model_id):
303+
if self._is_built_in_or_fine_tuned(model_id):
279304
model_dir = os.path.join(self._builtin_model_dir, f"{model_id}")
280305
return fetch_built_in_model(
281306
get_built_in_model_type(self._model_info_map[model_id].model_type),
@@ -312,7 +337,7 @@ def save_model(self, model_id: str, model: nn.Module):
312337
Whether saving succeeded
313338
"""
314339
with self._lock_pool.get_lock(model_id).write_lock():
315-
if self._is_built_in(model_id):
340+
if self._is_built_in_or_fine_tuned(model_id):
316341
model_dir = os.path.join(self._builtin_model_dir, f"{model_id}")
317342
model.save_pretrained(model_dir)
318343
else:
@@ -343,12 +368,31 @@ def get_ckpt_path(self, model_id: str) -> str:
343368
# Only support built-in models for now
344369
return os.path.join(self._builtin_model_dir, f"{model_id}")
345370

346-
def show_models(self) -> TShowModelsResp:
371+
def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
372+
resp_status = TSStatus(
373+
code=TSStatusCode.SUCCESS_STATUS.value,
374+
message="Show models successfully",
375+
)
376+
if req.modelId:
377+
if req.modelId in self._model_info_map:
378+
model_info = self._model_info_map[req.modelId]
379+
return TShowModelsResp(
380+
status=resp_status,
381+
modelIdList=[req.modelId],
382+
modelTypeMap={req.modelId: model_info.model_type},
383+
categoryMap={req.modelId: model_info.category.value},
384+
stateMap={req.modelId: model_info.state.value},
385+
)
386+
else:
387+
return TShowModelsResp(
388+
status=resp_status,
389+
modelIdList=[],
390+
modelTypeMap={},
391+
categoryMap={},
392+
stateMap={},
393+
)
347394
return TShowModelsResp(
348-
status=TSStatus(
349-
code=TSStatusCode.SUCCESS_STATUS.value,
350-
message="Show models successfully",
351-
),
395+
status=resp_status,
352396
modelIdList=list(self._model_info_map.keys()),
353397
modelTypeMap=dict(
354398
(model_id, model_info.model_type)

iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
package org.apache.iotdb.confignode.manager;
2121

22+
import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq;
2223
import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
2324
import org.apache.iotdb.common.rpc.thrift.TEndPoint;
2425
import org.apache.iotdb.common.rpc.thrift.TSStatus;
@@ -105,7 +106,11 @@ public TShowModelResp showModel(final TShowModelReq req) {
105106
new TEndPoint(registeredAINode.getInternalAddress(), registeredAINode.getInternalPort());
106107
try (AINodeClient client =
107108
AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint)) {
108-
TShowModelsResp resp = client.showModels();
109+
TShowModelsReq showModelsReq = new TShowModelsReq();
110+
if (req.isSetModelId()) {
111+
showModelsReq.setModelId(req.getModelId());
112+
}
113+
TShowModelsResp resp = client.showModels(showModelsReq);
109114
TShowModelResp res =
110115
new TShowModelResp().setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
111116
res.setModelIdList(resp.getModelIdList());

iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
package org.apache.iotdb.confignode.procedure.impl.model;
2121

22+
import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration;
2223
import org.apache.iotdb.common.rpc.thrift.TSStatus;
2324
import org.apache.iotdb.commons.client.ainode.AINodeClient;
2425
import org.apache.iotdb.commons.client.ainode.AINodeClientManager;
@@ -101,33 +102,35 @@ protected Flow executeFromState(ConfigNodeProcedureEnv env, DropModelState state
101102
private void dropModelOnAINode(ConfigNodeProcedureEnv env) {
102103
LOGGER.info("Start to drop model file [{}] on AI Node", modelName);
103104

104-
List<Integer> nodeIds =
105-
env.getConfigManager().getModelManager().getModelDistributions(modelName);
106-
for (Integer nodeId : nodeIds) {
107-
try (AINodeClient client =
108-
AINodeClientManager.getInstance()
109-
.borrowClient(
110-
env.getConfigManager()
111-
.getNodeManager()
112-
.getRegisteredAINode(nodeId)
113-
.getLocation()
114-
.getInternalEndPoint())) {
115-
TSStatus status = client.deleteModel(modelName);
116-
if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
117-
LOGGER.warn(
118-
"Failed to drop model [{}] on AINode [{}], status: {}",
119-
modelName,
120-
nodeId,
121-
status.getMessage());
122-
}
123-
} catch (Exception e) {
124-
LOGGER.warn(
125-
"Failed to drop model [{}] on AINode [{}], status: {}",
126-
modelName,
127-
nodeId,
128-
e.getMessage());
129-
}
130-
}
105+
List<TAINodeConfiguration> aiNodes =
106+
env.getConfigManager().getNodeManager().getRegisteredAINodes();
107+
aiNodes.forEach(
108+
aiNode -> {
109+
int nodeId = aiNode.getLocation().getAiNodeId();
110+
try (AINodeClient client =
111+
AINodeClientManager.getInstance()
112+
.borrowClient(
113+
env.getConfigManager()
114+
.getNodeManager()
115+
.getRegisteredAINode(nodeId)
116+
.getLocation()
117+
.getInternalEndPoint())) {
118+
TSStatus status = client.deleteModel(modelName);
119+
if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
120+
LOGGER.warn(
121+
"Failed to drop model [{}] on AINode [{}], status: {}",
122+
modelName,
123+
nodeId,
124+
status.getMessage());
125+
}
126+
} catch (Exception e) {
127+
LOGGER.warn(
128+
"Failed to drop model [{}] on AINode [{}], status: {}",
129+
modelName,
130+
nodeId,
131+
e.getMessage());
132+
}
133+
});
131134
}
132135

133136
private void dropModelOnConfigNode(ConfigNodeProcedureEnv env) {

iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp;
2929
import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq;
3030
import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp;
31+
import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq;
3132
import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
3233
import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq;
3334
import org.apache.iotdb.ainode.rpc.thrift.TWindowParams;
@@ -159,9 +160,9 @@ public TSStatus deleteModel(String modelId) throws TException {
159160
}
160161
}
161162

162-
public TShowModelsResp showModels() throws TException {
163+
public TShowModelsResp showModels(TShowModelsReq req) throws TException {
163164
try {
164-
return client.showModels();
165+
return client.showModels(req);
165166
} catch (TException e) {
166167
logger.warn(
167168
"Failed to connect to AINode from ConfigNode when executing {}: {}",

iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ struct TForecastResp {
9797
2: required binary forecastResult
9898
}
9999

100+
struct TShowModelsReq {
101+
1: optional string modelId
102+
}
103+
100104
struct TShowModelsResp {
101105
1: required common.TSStatus status
102106
2: optional list<string> modelIdList
@@ -108,7 +112,7 @@ struct TShowModelsResp {
108112
service IAINodeRPCService {
109113

110114
// -------------- For Config Node --------------
111-
TShowModelsResp showModels()
115+
TShowModelsResp showModels(TShowModelsReq req)
112116

113117
common.TSStatus deleteModel(TDeleteModelReq req)
114118

0 commit comments

Comments
 (0)