diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java index b503d67f814ba..da4f8e9c53683 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java @@ -20,14 +20,10 @@ package org.apache.iotdb.ainode.it; import org.apache.iotdb.it.env.EnvFactory; -import org.apache.iotdb.it.framework.IoTDBTestRunner; -import org.apache.iotdb.itbase.category.AIClusterIT; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.runner.RunWith; import java.io.File; import java.sql.Connection; @@ -40,8 +36,8 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; -@RunWith(IoTDBTestRunner.class) -@Category({AIClusterIT.class}) +// @RunWith(IoTDBTestRunner.class) +// @Category({AIClusterIT.class}) public class AINodeBasicIT { static final String MODEL_PATH = System.getProperty("user.dir") diff --git a/iotdb-core/ainode/ainode/core/constant.py b/iotdb-core/ainode/ainode/core/constant.py index 24d13a12ab8a9..98a66abc01995 100644 --- a/iotdb-core/ainode/ainode/core/constant.py +++ b/iotdb-core/ainode/ainode/core/constant.py @@ -53,8 +53,8 @@ TRIAL_ID_PREFIX = "__trial_" DEFAULT_TRIAL_ID = TRIAL_ID_PREFIX + "0" -DEFAULT_MODEL_FILE_NAME = "model.pt" -DEFAULT_CONFIG_FILE_NAME = "config.yaml" +DEFAULT_MODEL_FILE_NAME = "model.safetensors" +DEFAULT_CONFIG_FILE_NAME = "config.json" DEFAULT_CHUNK_SIZE = 8192 DEFAULT_RECONNECT_TIMEOUT = 20 diff --git a/iotdb-core/ainode/ainode/core/ingress/dataset.py b/iotdb-core/ainode/ainode/core/ingress/dataset.py index c2410ed4374d9..9783c6c85c1a3 100644 --- a/iotdb-core/ainode/ainode/core/ingress/dataset.py +++ b/iotdb-core/ainode/ainode/core/ingress/dataset.py @@ -15,10 +15,8 @@ # specific language governing permissions and limitations # under the License. # -from torch.utils.data import Dataset -from ainode.core.ingress.iotdb import IoTDBTableModelDataset, IoTDBTreeModelDataset -from ainode.core.util.decorator import singleton +from torch.utils.data import Dataset class BasicDatabaseDataset(Dataset): @@ -32,31 +30,3 @@ def __init__(self, ip: str, port: int, input_len: int, output_len: int): super().__init__(ip, port) self.input_len = input_len self.output_len = output_len - - -def register_dataset(key: str, dataset: Dataset): - DatasetFactory().register(key, dataset) - - -@singleton -class DatasetFactory(object): - - def __init__(self): - self.dataset_list = { - "iotdb.table": IoTDBTableModelDataset, - "iotdb.tree": IoTDBTreeModelDataset, - } - - def register(self, key: str, dataset: Dataset): - if key not in self.dataset_list: - self.dataset_list[key] = dataset - else: - raise KeyError(f"Dataset {key} already exists") - - def deregister(self, key: str): - del self.dataset_list[key] - - def get_dataset(self, key: str): - if key not in self.dataset_list.keys(): - raise KeyError(f"Dataset {key} does not exist") - return self.dataset_list[key] diff --git a/iotdb-core/ainode/ainode/core/ingress/iotdb.py b/iotdb-core/ainode/ainode/core/ingress/iotdb.py index 4b034ac880842..d1b344b43a8de 100644 --- a/iotdb-core/ainode/ainode/core/ingress/iotdb.py +++ b/iotdb-core/ainode/ainode/core/ingress/iotdb.py @@ -15,16 +15,19 @@ # specific language governing permissions and limitations # under the License. # +import numpy as np import torch from iotdb.Session import Session from iotdb.table_session import TableSession, TableSessionConfig from iotdb.utils.Field import Field from iotdb.utils.IoTDBConstants import TSDataType -from util.cache import MemoryLRUCache +from torch.utils.data import Dataset from ainode.core.config import AINodeDescriptor from ainode.core.ingress.dataset import BasicDatabaseForecastDataset from ainode.core.log import Logger +from ainode.core.util.cache import MemoryLRUCache +from ainode.core.util.decorator import singleton logger = Logger() @@ -55,7 +58,7 @@ def __init__( model_id: str, input_len: int, out_len: int, - schema_list: list, + data_schema_list: list, ip: str = "127.0.0.1", port: int = 6667, username: str = "root", @@ -81,15 +84,16 @@ def __init__( ) self.session.open(False) self.context_length = self.input_len + self.output_len - self._fetch_schema(schema_list) + self.token_num = self.context_length // self.input_len + self._fetch_schema(data_schema_list) self.start_idx = int(self.total_count * start_split) self.end_idx = int(self.total_count * end_split) self.cache_enable = _cache_enable() self.cache_key_prefix = model_id + "_" - def _fetch_schema(self, schema_list: list): + def _fetch_schema(self, data_schema_list: list): series_to_length = {} - for schema in schema_list: + for schema in data_schema_list: path_pattern = schema.schemaName series_list = [] time_condition = ( @@ -155,10 +159,13 @@ def __getitem__(self, index): if series_data is not None: series_data = torch.tensor(series_data) result = series_data[window_index : window_index + self.context_length] - return result[0 : self.input_len].unsqueeze(-1), result[ - -self.output_len : - ].unsqueeze(-1) + return ( + result[0 : self.input_len], + result[-self.output_len :], + np.ones(self.token_num, dtype=np.int32), + ) result = [] + sql = "" try: if self.cache_enable: sql = self.FETCH_SERIES_SQL % ( @@ -178,13 +185,15 @@ def __getitem__(self, index): while query_result.has_next(): result.append(get_field_value(query_result.next().get_fields()[0])) except Exception as e: - logger.error(e) + logger.error("Executing sql: {} with exception: {}".format(sql, e)) if self.cache_enable: self.cache.put(cache_key, result) result = torch.tensor(result) - return result[0 : self.input_len].unsqueeze(-1), result[ - -self.output_len : - ].unsqueeze(-1) + return ( + result[0 : self.input_len], + result[-self.output_len :], + np.ones(self.token_num, dtype=np.int32), + ) def __len__(self): return self.end_idx - self.start_idx @@ -228,9 +237,9 @@ def __init__( self.session = TableSession(table_session_config) self.context_length = self.input_len + self.output_len + self.token_num = self.context_length // self.input_len self._fetch_schema(data_schema_list) - v = self.total_count * start_split self.start_index = int(self.total_count * start_split) self.end_index = self.total_count * end_split @@ -285,19 +294,52 @@ def __getitem__(self, index): schema = series.split(".") result = [] + sql = self.FETCH_SERIES_SQL % ( + schema[0:1], + schema[2], + window_index, + self.context_length, + ) try: - with self.session.execute_query_statement( - self.FETCH_SERIES_SQL - % (schema[0:1], schema[2], window_index, self.context_length) - ) as query_result: + with self.session.execute_query_statement(sql) as query_result: while query_result.has_next(): result.append(get_field_value(query_result.next().get_fields()[0])) except Exception as e: - logger.error("Error happens when loading dataset str(e))") + logger.error("Executing sql: {} with exception: {}".format(sql, e)) result = torch.tensor(result) - return result[0 : self.input_len].unsqueeze(-1), result[ - -self.output_len : - ].unsqueeze(-1) + return ( + result[0 : self.input_len], + result[-self.output_len :], + np.ones(self.token_num, dtype=np.int32), + ) def __len__(self): return self.end_index - self.start_index + + +def register_dataset(key: str, dataset: Dataset): + DatasetFactory().register(key, dataset) + + +@singleton +class DatasetFactory(object): + + def __init__(self): + self.dataset_list = { + "iotdb.table": IoTDBTableModelDataset, + "iotdb.tree": IoTDBTreeModelDataset, + } + + def register(self, key: str, dataset: Dataset): + if key not in self.dataset_list: + self.dataset_list[key] = dataset + else: + raise KeyError(f"Dataset {key} already exists") + + def deregister(self, key: str): + del self.dataset_list[key] + + def get_dataset(self, key: str): + if key not in self.dataset_list.keys(): + raise KeyError(f"Dataset {key} does not exist") + return self.dataset_list[key] diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java index cbd67795774fa..fa228a233e152 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java @@ -27,6 +27,7 @@ import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId; import org.apache.iotdb.common.rpc.thrift.TDataNodeConfiguration; import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; +import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.common.rpc.thrift.TFlushReq; import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet; import org.apache.iotdb.common.rpc.thrift.TSStatus; @@ -41,7 +42,6 @@ import org.apache.iotdb.commons.auth.entity.PrivilegeUnion; import org.apache.iotdb.commons.client.ainode.AINodeClient; import org.apache.iotdb.commons.client.ainode.AINodeClientManager; -import org.apache.iotdb.commons.client.ainode.AINodeInfo; import org.apache.iotdb.commons.cluster.NodeStatus; import org.apache.iotdb.commons.cluster.NodeType; import org.apache.iotdb.commons.conf.CommonConfig; @@ -136,6 +136,7 @@ import org.apache.iotdb.confignode.persistence.schema.ClusterSchemaInfo; import org.apache.iotdb.confignode.persistence.subscription.SubscriptionInfo; import org.apache.iotdb.confignode.procedure.impl.schema.SchemaUtils; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo; import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterReq; import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartReq; import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartResp; @@ -2671,7 +2672,7 @@ private List fetchSchemaForTableModel(TCreateTrainingReq req) { } } for (String tableName : dataSchemaForTable.getTableList()) { - dataSchemaList.add(new IDataSchema(dataSchemaForTable.curDatabase + DOT + tableName)); + dataSchemaList.add(new IDataSchema(tableName)); } return dataSchemaList; } @@ -2685,7 +2686,7 @@ public TSStatus createTraining(TCreateTrainingReq req) { TTrainingReq trainingReq = new TTrainingReq(); trainingReq.setModelId(req.getModelId()); - trainingReq.setModelType("timer_xl"); + trainingReq.setModelType("sundial"); if (req.existingModelId != null) { trainingReq.setExistingModelId(req.getExistingModelId()); } @@ -2710,8 +2711,11 @@ public TSStatus createTraining(TCreateTrainingReq req) { updateModelInfo(new TUpdateModelInfoReq(req.modelId, ModelStatus.TRAINING.ordinal())); trainingReq.setTargetDataSchema(dataSchema); + TAINodeInfo registeredAINode = getNodeManager().getRegisteredAINodeInfoList().get(0); + TEndPoint targetAINodeEndPoint = + new TEndPoint(registeredAINode.getInternalAddress(), registeredAINode.getInternalPort()); try (AINodeClient client = - AINodeClientManager.getInstance().borrowClient(AINodeInfo.endPoint)) { + AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint)) { status = client.createTrainingTask(trainingReq); if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { throw new IllegalArgumentException(status.message); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java index a169096341f98..95d9ecaa2788e 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java @@ -1377,8 +1377,8 @@ protected IConfigTask visitCreateTraining(CreateTraining node, MPPQueryContext c node.isUseAllData(), node.getTargetTimeRanges(), node.getExistingModelId(), - node.getTargetDbs(), - tableList); + tableList, + node.getTargetDbs()); } @Override diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java index 84a6aa45f6d45..91d3258dba1ec 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java @@ -53,9 +53,6 @@ public CreateTrainingTask( String existingModelId, List targetTables, List targetDbs) { - if (!modelType.equalsIgnoreCase("timer_xl")) { - throw new UnsupportedOperationException("Only TimerXL model is supported now."); - } this.modelId = modelId; this.modelType = modelType; this.parameters = parameters; @@ -76,9 +73,6 @@ public CreateTrainingTask( List> timeRanges, String existingModelId, List targetPaths) { - if (!modelType.equalsIgnoreCase("timer_xl")) { - throw new UnsupportedOperationException("Only TimerXL model is supported now."); - } this.modelId = modelId; this.modelType = modelType; this.parameters = parameters; diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java index ae7a521a19a7e..346a459136ae3 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java @@ -199,7 +199,7 @@ public TForecastResp forecast( TSStatus tsStatus = new TSStatus(CAN_NOT_CONNECT_AINODE.getStatusCode()); tsStatus.setMessage( String.format( - "Failed to connect to AINode from DataNode when executing %s: %s", + "Failed to connect to AINode when executing %s: %s", Thread.currentThread().getStackTrace()[1].getMethodName(), e.getMessage())); return new TForecastResp(tsStatus, ByteBuffer.allocate(0)); } @@ -210,7 +210,7 @@ public TSStatus createTrainingTask(TTrainingReq req) throws TException { return client.createTrainingTask(req); } catch (TException e) { logger.warn( - "Failed to connect to AINode from DataNode when executing {}: {}", + "Failed to connect to AINode when executing {}: {}", Thread.currentThread().getStackTrace()[1].getMethodName(), e.getMessage()); throw new TException(MSG_CONNECTION_FAIL); diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeInfo.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeInfo.java deleted file mode 100644 index d6f3a65527953..0000000000000 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeInfo.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.commons.client.ainode; - -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.commons.conf.CommonDescriptor; - -public class AINodeInfo { - // currently, we only support one AINode - public static final TEndPoint endPoint = - CommonDescriptor.getInstance().getConfig().getTargetAINodeEndPoint(); -} diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java index e82cc2fb1dd08..e5960a57c6b83 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java @@ -19,7 +19,6 @@ package org.apache.iotdb.commons.conf; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.commons.client.property.ClientPoolProperty.DefaultProperty; import org.apache.iotdb.commons.cluster.NodeStatus; import org.apache.iotdb.commons.enums.HandleSystemErrorStrategy; @@ -174,9 +173,6 @@ public class CommonConfig { /** Disk Monitor. */ private double diskSpaceWarningThreshold = 0.05; - /** Ip and port of target AI node. */ - private TEndPoint targetAINodeEndPoint = new TEndPoint("127.0.0.1", 10810); - /** Time partition origin in milliseconds. */ private long timePartitionOrigin = 0; @@ -662,14 +658,6 @@ public void setStatusReason(String statusReason) { this.statusReason = statusReason; } - public TEndPoint getTargetAINodeEndPoint() { - return targetAINodeEndPoint; - } - - public void setTargetAINodeEndPoint(TEndPoint targetAINodeEndPoint) { - this.targetAINodeEndPoint = targetAINodeEndPoint; - } - public int getTTimePartitionSlotTransmitLimit() { return TTimePartitionSlotTransmitLimit; } diff --git a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift index 3f4c42d058a58..f767f35d67c79 100644 --- a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift +++ b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift @@ -1088,7 +1088,6 @@ struct TUpdateModelInfoReq { struct TDataSchemaForTable{ 1: required list databaseList 2: required list tableList - 3: required string curDatabase } struct TDataSchemaForTree{