Skip to content

Commit 7d5b7c5

Browse files
authored
[AINode] Enhance the robustness of AINode (#15695)
1 parent 761f4e7 commit 7d5b7c5

File tree

11 files changed

+80
-116
lines changed

11 files changed

+80
-116
lines changed

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,10 @@
2020
package org.apache.iotdb.ainode.it;
2121

2222
import org.apache.iotdb.it.env.EnvFactory;
23-
import org.apache.iotdb.it.framework.IoTDBTestRunner;
24-
import org.apache.iotdb.itbase.category.AIClusterIT;
2523

2624
import org.junit.AfterClass;
2725
import org.junit.BeforeClass;
2826
import org.junit.Test;
29-
import org.junit.experimental.categories.Category;
30-
import org.junit.runner.RunWith;
3127

3228
import java.io.File;
3329
import java.sql.Connection;
@@ -40,8 +36,8 @@
4036
import static org.junit.Assert.assertEquals;
4137
import static org.junit.Assert.fail;
4238

43-
@RunWith(IoTDBTestRunner.class)
44-
@Category({AIClusterIT.class})
39+
// @RunWith(IoTDBTestRunner.class)
40+
// @Category({AIClusterIT.class})
4541
public class AINodeBasicIT {
4642
static final String MODEL_PATH =
4743
System.getProperty("user.dir")

iotdb-core/ainode/ainode/core/constant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@
5353

5454
TRIAL_ID_PREFIX = "__trial_"
5555
DEFAULT_TRIAL_ID = TRIAL_ID_PREFIX + "0"
56-
DEFAULT_MODEL_FILE_NAME = "model.pt"
57-
DEFAULT_CONFIG_FILE_NAME = "config.yaml"
56+
DEFAULT_MODEL_FILE_NAME = "model.safetensors"
57+
DEFAULT_CONFIG_FILE_NAME = "config.json"
5858
DEFAULT_CHUNK_SIZE = 8192
5959

6060
DEFAULT_RECONNECT_TIMEOUT = 20

iotdb-core/ainode/ainode/core/ingress/dataset.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
#
18-
from torch.utils.data import Dataset
1918

20-
from ainode.core.ingress.iotdb import IoTDBTableModelDataset, IoTDBTreeModelDataset
21-
from ainode.core.util.decorator import singleton
19+
from torch.utils.data import Dataset
2220

2321

2422
class BasicDatabaseDataset(Dataset):
@@ -32,31 +30,3 @@ def __init__(self, ip: str, port: int, input_len: int, output_len: int):
3230
super().__init__(ip, port)
3331
self.input_len = input_len
3432
self.output_len = output_len
35-
36-
37-
def register_dataset(key: str, dataset: Dataset):
38-
DatasetFactory().register(key, dataset)
39-
40-
41-
@singleton
42-
class DatasetFactory(object):
43-
44-
def __init__(self):
45-
self.dataset_list = {
46-
"iotdb.table": IoTDBTableModelDataset,
47-
"iotdb.tree": IoTDBTreeModelDataset,
48-
}
49-
50-
def register(self, key: str, dataset: Dataset):
51-
if key not in self.dataset_list:
52-
self.dataset_list[key] = dataset
53-
else:
54-
raise KeyError(f"Dataset {key} already exists")
55-
56-
def deregister(self, key: str):
57-
del self.dataset_list[key]
58-
59-
def get_dataset(self, key: str):
60-
if key not in self.dataset_list.keys():
61-
raise KeyError(f"Dataset {key} does not exist")
62-
return self.dataset_list[key]

iotdb-core/ainode/ainode/core/ingress/iotdb.py

Lines changed: 63 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,19 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
#
18+
import numpy as np
1819
import torch
1920
from iotdb.Session import Session
2021
from iotdb.table_session import TableSession, TableSessionConfig
2122
from iotdb.utils.Field import Field
2223
from iotdb.utils.IoTDBConstants import TSDataType
23-
from util.cache import MemoryLRUCache
24+
from torch.utils.data import Dataset
2425

2526
from ainode.core.config import AINodeDescriptor
2627
from ainode.core.ingress.dataset import BasicDatabaseForecastDataset
2728
from ainode.core.log import Logger
29+
from ainode.core.util.cache import MemoryLRUCache
30+
from ainode.core.util.decorator import singleton
2831

2932
logger = Logger()
3033

@@ -55,7 +58,7 @@ def __init__(
5558
model_id: str,
5659
input_len: int,
5760
out_len: int,
58-
schema_list: list,
61+
data_schema_list: list,
5962
ip: str = "127.0.0.1",
6063
port: int = 6667,
6164
username: str = "root",
@@ -81,15 +84,16 @@ def __init__(
8184
)
8285
self.session.open(False)
8386
self.context_length = self.input_len + self.output_len
84-
self._fetch_schema(schema_list)
87+
self.token_num = self.context_length // self.input_len
88+
self._fetch_schema(data_schema_list)
8589
self.start_idx = int(self.total_count * start_split)
8690
self.end_idx = int(self.total_count * end_split)
8791
self.cache_enable = _cache_enable()
8892
self.cache_key_prefix = model_id + "_"
8993

90-
def _fetch_schema(self, schema_list: list):
94+
def _fetch_schema(self, data_schema_list: list):
9195
series_to_length = {}
92-
for schema in schema_list:
96+
for schema in data_schema_list:
9397
path_pattern = schema.schemaName
9498
series_list = []
9599
time_condition = (
@@ -155,10 +159,13 @@ def __getitem__(self, index):
155159
if series_data is not None:
156160
series_data = torch.tensor(series_data)
157161
result = series_data[window_index : window_index + self.context_length]
158-
return result[0 : self.input_len].unsqueeze(-1), result[
159-
-self.output_len :
160-
].unsqueeze(-1)
162+
return (
163+
result[0 : self.input_len],
164+
result[-self.output_len :],
165+
np.ones(self.token_num, dtype=np.int32),
166+
)
161167
result = []
168+
sql = ""
162169
try:
163170
if self.cache_enable:
164171
sql = self.FETCH_SERIES_SQL % (
@@ -178,13 +185,15 @@ def __getitem__(self, index):
178185
while query_result.has_next():
179186
result.append(get_field_value(query_result.next().get_fields()[0]))
180187
except Exception as e:
181-
logger.error(e)
188+
logger.error("Executing sql: {} with exception: {}".format(sql, e))
182189
if self.cache_enable:
183190
self.cache.put(cache_key, result)
184191
result = torch.tensor(result)
185-
return result[0 : self.input_len].unsqueeze(-1), result[
186-
-self.output_len :
187-
].unsqueeze(-1)
192+
return (
193+
result[0 : self.input_len],
194+
result[-self.output_len :],
195+
np.ones(self.token_num, dtype=np.int32),
196+
)
188197

189198
def __len__(self):
190199
return self.end_idx - self.start_idx
@@ -228,9 +237,9 @@ def __init__(
228237

229238
self.session = TableSession(table_session_config)
230239
self.context_length = self.input_len + self.output_len
240+
self.token_num = self.context_length // self.input_len
231241
self._fetch_schema(data_schema_list)
232242

233-
v = self.total_count * start_split
234243
self.start_index = int(self.total_count * start_split)
235244
self.end_index = self.total_count * end_split
236245

@@ -285,19 +294,52 @@ def __getitem__(self, index):
285294
schema = series.split(".")
286295

287296
result = []
297+
sql = self.FETCH_SERIES_SQL % (
298+
schema[0:1],
299+
schema[2],
300+
window_index,
301+
self.context_length,
302+
)
288303
try:
289-
with self.session.execute_query_statement(
290-
self.FETCH_SERIES_SQL
291-
% (schema[0:1], schema[2], window_index, self.context_length)
292-
) as query_result:
304+
with self.session.execute_query_statement(sql) as query_result:
293305
while query_result.has_next():
294306
result.append(get_field_value(query_result.next().get_fields()[0]))
295307
except Exception as e:
296-
logger.error("Error happens when loading dataset str(e))")
308+
logger.error("Executing sql: {} with exception: {}".format(sql, e))
297309
result = torch.tensor(result)
298-
return result[0 : self.input_len].unsqueeze(-1), result[
299-
-self.output_len :
300-
].unsqueeze(-1)
310+
return (
311+
result[0 : self.input_len],
312+
result[-self.output_len :],
313+
np.ones(self.token_num, dtype=np.int32),
314+
)
301315

302316
def __len__(self):
303317
return self.end_index - self.start_index
318+
319+
320+
def register_dataset(key: str, dataset: Dataset):
321+
DatasetFactory().register(key, dataset)
322+
323+
324+
@singleton
325+
class DatasetFactory(object):
326+
327+
def __init__(self):
328+
self.dataset_list = {
329+
"iotdb.table": IoTDBTableModelDataset,
330+
"iotdb.tree": IoTDBTreeModelDataset,
331+
}
332+
333+
def register(self, key: str, dataset: Dataset):
334+
if key not in self.dataset_list:
335+
self.dataset_list[key] = dataset
336+
else:
337+
raise KeyError(f"Dataset {key} already exists")
338+
339+
def deregister(self, key: str):
340+
del self.dataset_list[key]
341+
342+
def get_dataset(self, key: str):
343+
if key not in self.dataset_list.keys():
344+
raise KeyError(f"Dataset {key} does not exist")
345+
return self.dataset_list[key]

iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId;
2828
import org.apache.iotdb.common.rpc.thrift.TDataNodeConfiguration;
2929
import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation;
30+
import org.apache.iotdb.common.rpc.thrift.TEndPoint;
3031
import org.apache.iotdb.common.rpc.thrift.TFlushReq;
3132
import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet;
3233
import org.apache.iotdb.common.rpc.thrift.TSStatus;
@@ -41,7 +42,6 @@
4142
import org.apache.iotdb.commons.auth.entity.PrivilegeUnion;
4243
import org.apache.iotdb.commons.client.ainode.AINodeClient;
4344
import org.apache.iotdb.commons.client.ainode.AINodeClientManager;
44-
import org.apache.iotdb.commons.client.ainode.AINodeInfo;
4545
import org.apache.iotdb.commons.cluster.NodeStatus;
4646
import org.apache.iotdb.commons.cluster.NodeType;
4747
import org.apache.iotdb.commons.conf.CommonConfig;
@@ -136,6 +136,7 @@
136136
import org.apache.iotdb.confignode.persistence.schema.ClusterSchemaInfo;
137137
import org.apache.iotdb.confignode.persistence.subscription.SubscriptionInfo;
138138
import org.apache.iotdb.confignode.procedure.impl.schema.SchemaUtils;
139+
import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo;
139140
import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterReq;
140141
import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartReq;
141142
import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartResp;
@@ -2671,7 +2672,7 @@ private List<IDataSchema> fetchSchemaForTableModel(TCreateTrainingReq req) {
26712672
}
26722673
}
26732674
for (String tableName : dataSchemaForTable.getTableList()) {
2674-
dataSchemaList.add(new IDataSchema(dataSchemaForTable.curDatabase + DOT + tableName));
2675+
dataSchemaList.add(new IDataSchema(tableName));
26752676
}
26762677
return dataSchemaList;
26772678
}
@@ -2685,7 +2686,7 @@ public TSStatus createTraining(TCreateTrainingReq req) {
26852686

26862687
TTrainingReq trainingReq = new TTrainingReq();
26872688
trainingReq.setModelId(req.getModelId());
2688-
trainingReq.setModelType("timer_xl");
2689+
trainingReq.setModelType("sundial");
26892690
if (req.existingModelId != null) {
26902691
trainingReq.setExistingModelId(req.getExistingModelId());
26912692
}
@@ -2710,8 +2711,11 @@ public TSStatus createTraining(TCreateTrainingReq req) {
27102711
updateModelInfo(new TUpdateModelInfoReq(req.modelId, ModelStatus.TRAINING.ordinal()));
27112712
trainingReq.setTargetDataSchema(dataSchema);
27122713

2714+
TAINodeInfo registeredAINode = getNodeManager().getRegisteredAINodeInfoList().get(0);
2715+
TEndPoint targetAINodeEndPoint =
2716+
new TEndPoint(registeredAINode.getInternalAddress(), registeredAINode.getInternalPort());
27132717
try (AINodeClient client =
2714-
AINodeClientManager.getInstance().borrowClient(AINodeInfo.endPoint)) {
2718+
AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint)) {
27152719
status = client.createTrainingTask(trainingReq);
27162720
if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
27172721
throw new IllegalArgumentException(status.message);

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,8 +1377,8 @@ protected IConfigTask visitCreateTraining(CreateTraining node, MPPQueryContext c
13771377
node.isUseAllData(),
13781378
node.getTargetTimeRanges(),
13791379
node.getExistingModelId(),
1380-
node.getTargetDbs(),
1381-
tableList);
1380+
tableList,
1381+
node.getTargetDbs());
13821382
}
13831383

13841384
@Override

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,6 @@ public CreateTrainingTask(
5353
String existingModelId,
5454
List<String> targetTables,
5555
List<String> targetDbs) {
56-
if (!modelType.equalsIgnoreCase("timer_xl")) {
57-
throw new UnsupportedOperationException("Only TimerXL model is supported now.");
58-
}
5956
this.modelId = modelId;
6057
this.modelType = modelType;
6158
this.parameters = parameters;
@@ -76,9 +73,6 @@ public CreateTrainingTask(
7673
List<List<Long>> timeRanges,
7774
String existingModelId,
7875
List<String> targetPaths) {
79-
if (!modelType.equalsIgnoreCase("timer_xl")) {
80-
throw new UnsupportedOperationException("Only TimerXL model is supported now.");
81-
}
8276
this.modelId = modelId;
8377
this.modelType = modelType;
8478
this.parameters = parameters;

iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ public TForecastResp forecast(
199199
TSStatus tsStatus = new TSStatus(CAN_NOT_CONNECT_AINODE.getStatusCode());
200200
tsStatus.setMessage(
201201
String.format(
202-
"Failed to connect to AINode from DataNode when executing %s: %s",
202+
"Failed to connect to AINode when executing %s: %s",
203203
Thread.currentThread().getStackTrace()[1].getMethodName(), e.getMessage()));
204204
return new TForecastResp(tsStatus, ByteBuffer.allocate(0));
205205
}
@@ -210,7 +210,7 @@ public TSStatus createTrainingTask(TTrainingReq req) throws TException {
210210
return client.createTrainingTask(req);
211211
} catch (TException e) {
212212
logger.warn(
213-
"Failed to connect to AINode from DataNode when executing {}: {}",
213+
"Failed to connect to AINode when executing {}: {}",
214214
Thread.currentThread().getStackTrace()[1].getMethodName(),
215215
e.getMessage());
216216
throw new TException(MSG_CONNECTION_FAIL);

iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeInfo.java

Lines changed: 0 additions & 29 deletions
This file was deleted.

0 commit comments

Comments
 (0)