@@ -20,14 +20,10 @@
package org.apache.iotdb.ainode.it;

import org.apache.iotdb.it.env.EnvFactory;
import org.apache.iotdb.it.framework.IoTDBTestRunner;
import org.apache.iotdb.itbase.category.AIClusterIT;

import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;

import java.io.File;
import java.sql.Connection;
@@ -40,8 +36,8 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

@RunWith(IoTDBTestRunner.class)
@Category({AIClusterIT.class})
// @RunWith(IoTDBTestRunner.class)
// @Category({AIClusterIT.class})
public class AINodeBasicIT {
static final String MODEL_PATH =
System.getProperty("user.dir")
4 changes: 2 additions & 2 deletions iotdb-core/ainode/ainode/core/constant.py
@@ -53,8 +53,8 @@

TRIAL_ID_PREFIX = "__trial_"
DEFAULT_TRIAL_ID = TRIAL_ID_PREFIX + "0"
DEFAULT_MODEL_FILE_NAME = "model.pt"
DEFAULT_CONFIG_FILE_NAME = "config.yaml"
DEFAULT_MODEL_FILE_NAME = "model.safetensors"
DEFAULT_CONFIG_FILE_NAME = "config.json"
DEFAULT_CHUNK_SIZE = 8192

DEFAULT_RECONNECT_TIMEOUT = 20
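
The default artifact names switch from a pickled model.pt plus YAML config to the safetensors/JSON pair used by HuggingFace-style checkpoints. A minimal sketch of reading and writing that layout, assuming the standard `safetensors` package (the helper functions are illustrative, not from this PR):

```python
# Illustrative only -- not code from this PR. Assumes the `safetensors`
# package; file names match DEFAULT_MODEL_FILE_NAME / DEFAULT_CONFIG_FILE_NAME.
import json
import os

from safetensors.torch import load_file, save_file


def save_checkpoint(model, config: dict, model_dir: str) -> None:
    # Weights as a flat tensor dict, config as plain JSON.
    save_file(model.state_dict(), os.path.join(model_dir, "model.safetensors"))
    with open(os.path.join(model_dir, "config.json"), "w") as f:
        json.dump(config, f)


def load_checkpoint(model, model_dir: str) -> dict:
    model.load_state_dict(load_file(os.path.join(model_dir, "model.safetensors")))
    with open(os.path.join(model_dir, "config.json")) as f:
        return json.load(f)
```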
32 changes: 1 addition & 31 deletions iotdb-core/ainode/ainode/core/ingress/dataset.py
@@ -15,10 +15,8 @@
# specific language governing permissions and limitations
# under the License.
#
from torch.utils.data import Dataset

from ainode.core.ingress.iotdb import IoTDBTableModelDataset, IoTDBTreeModelDataset
from ainode.core.util.decorator import singleton
from torch.utils.data import Dataset


class BasicDatabaseDataset(Dataset):
@@ -32,31 +30,3 @@ def __init__(self, ip: str, port: int, input_len: int, output_len: int):
super().__init__(ip, port)
self.input_len = input_len
self.output_len = output_len


def register_dataset(key: str, dataset: Dataset):
DatasetFactory().register(key, dataset)


@singleton
class DatasetFactory(object):

def __init__(self):
self.dataset_list = {
"iotdb.table": IoTDBTableModelDataset,
"iotdb.tree": IoTDBTreeModelDataset,
}

def register(self, key: str, dataset: Dataset):
if key not in self.dataset_list:
self.dataset_list[key] = dataset
else:
raise KeyError(f"Dataset {key} already exists")

def deregister(self, key: str):
del self.dataset_list[key]

def get_dataset(self, key: str):
if key not in self.dataset_list.keys():
raise KeyError(f"Dataset {key} does not exist")
return self.dataset_list[key]
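
The `register_dataset` helper and `DatasetFactory` deleted here reappear verbatim at the bottom of iotdb.py below, presumably to break the import cycle: the old dataset.py imported the IoTDB dataset classes from iotdb.py, while iotdb.py imports `BasicDatabaseForecastDataset` from dataset.py. What stays in this file is just the base classes; a hypothetical subclass showing the constructor contract they keep:

```python
# Hypothetical subclass -- not part of this PR; shows the contract that
# BasicDatabaseForecastDataset keeps in dataset.py after the factory moved out.
from ainode.core.ingress.dataset import BasicDatabaseForecastDataset


class MyForecastDataset(BasicDatabaseForecastDataset):
    def __init__(self, ip: str = "127.0.0.1", port: int = 6667):
        # input_len/output_len values are arbitrary for this sketch.
        super().__init__(ip, port, input_len=96, output_len=96)
```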
84 changes: 63 additions & 21 deletions iotdb-core/ainode/ainode/core/ingress/iotdb.py
@@ -15,16 +15,19 @@
# specific language governing permissions and limitations
# under the License.
#
import numpy as np
import torch
from iotdb.Session import Session
from iotdb.table_session import TableSession, TableSessionConfig
from iotdb.utils.Field import Field
from iotdb.utils.IoTDBConstants import TSDataType
from util.cache import MemoryLRUCache
from torch.utils.data import Dataset

from ainode.core.config import AINodeDescriptor
from ainode.core.ingress.dataset import BasicDatabaseForecastDataset
from ainode.core.log import Logger
from ainode.core.util.cache import MemoryLRUCache
from ainode.core.util.decorator import singleton

logger = Logger()

@@ -55,7 +58,7 @@ def __init__(
model_id: str,
input_len: int,
out_len: int,
schema_list: list,
data_schema_list: list,
ip: str = "127.0.0.1",
port: int = 6667,
username: str = "root",
@@ -81,15 +84,16 @@
)
self.session.open(False)
self.context_length = self.input_len + self.output_len
self._fetch_schema(schema_list)
self.token_num = self.context_length // self.input_len
self._fetch_schema(data_schema_list)
self.start_idx = int(self.total_count * start_split)
self.end_idx = int(self.total_count * end_split)
self.cache_enable = _cache_enable()
self.cache_key_prefix = model_id + "_"
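
The new cache fields prefix every key with the model id, so two models training over the same series do not collide. The read path sits in collapsed lines above `__getitem__`, but the `put` call visible below implies an interaction roughly like this sketch (the key layout, `get()` returning None on a miss, and the helper name are all inferences, not code from this PR):

```python
# Assumed shape of the cache path; only put() appears in this diff.
def _load_series(self, series_name: str, window_index: int):
    cache_key = self.cache_key_prefix + series_name + "_" + str(window_index)
    series_data = self.cache.get(cache_key)        # assumed: None on miss
    if series_data is None:
        series_data = self._run_fetch_series_sql(  # hypothetical helper
            series_name, window_index
        )
        self.cache.put(cache_key, series_data)     # put() is visible below
    return series_data
```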

def _fetch_schema(self, schema_list: list):
def _fetch_schema(self, data_schema_list: list):
series_to_length = {}
for schema in schema_list:
for schema in data_schema_list:
path_pattern = schema.schemaName
series_list = []
time_condition = (
@@ -155,10 +159,13 @@ def __getitem__(self, index):
if series_data is not None:
series_data = torch.tensor(series_data)
result = series_data[window_index : window_index + self.context_length]
return result[0 : self.input_len].unsqueeze(-1), result[
-self.output_len :
].unsqueeze(-1)
return (
result[0 : self.input_len],
result[-self.output_len :],
np.ones(self.token_num, dtype=np.int32),
)
result = []
sql = ""
try:
if self.cache_enable:
sql = self.FETCH_SERIES_SQL % (
@@ -178,13 +185,15 @@
while query_result.has_next():
result.append(get_field_value(query_result.next().get_fields()[0]))
except Exception as e:
logger.error(e)
logger.error("Executing sql: {} with exception: {}".format(sql, e))
if self.cache_enable:
self.cache.put(cache_key, result)
result = torch.tensor(result)
return result[0 : self.input_len].unsqueeze(-1), result[
-self.output_len :
].unsqueeze(-1)
return (
result[0 : self.input_len],
result[-self.output_len :],
np.ones(self.token_num, dtype=np.int32),
)

def __len__(self):
return self.end_idx - self.start_idx
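
Both the tree- and table-model datasets now return a third element, a ones mask of length `token_num = (input_len + output_len) // input_len` (2 when output length equals input length), and drop the trailing `unsqueeze(-1)` channel dimension. A sketch of the assumed downstream consumption:

```python
# Assumed consumer of the new sample layout; `dataset` stands in for either
# IoTDBTreeModelDataset or IoTDBTableModelDataset.
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=32)
for past, future, mask in loader:
    # past:   (batch, input_len)   -- no trailing channel dim anymore
    # future: (batch, output_len)
    # mask:   (batch, token_num)   -- int32 ones, collated from np.ones(...)
    ...
```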
@@ -228,9 +237,9 @@

self.session = TableSession(table_session_config)
self.context_length = self.input_len + self.output_len
self.token_num = self.context_length // self.input_len
self._fetch_schema(data_schema_list)

v = self.total_count * start_split
self.start_index = int(self.total_count * start_split)
self.end_index = self.total_count * end_split

@@ -285,19 +294,52 @@ def __getitem__(self, index):
schema = series.split(".")

result = []
sql = self.FETCH_SERIES_SQL % (
schema[0:1],
schema[2],
window_index,
self.context_length,
)
try:
with self.session.execute_query_statement(
self.FETCH_SERIES_SQL
% (schema[0:1], schema[2], window_index, self.context_length)
) as query_result:
with self.session.execute_query_statement(sql) as query_result:
while query_result.has_next():
result.append(get_field_value(query_result.next().get_fields()[0]))
except Exception as e:
logger.error("Error happens when loading dataset str(e))")
logger.error("Executing sql: {} with exception: {}".format(sql, e))
result = torch.tensor(result)
return result[0 : self.input_len].unsqueeze(-1), result[
-self.output_len :
].unsqueeze(-1)
return (
result[0 : self.input_len],
result[-self.output_len :],
np.ones(self.token_num, dtype=np.int32),
)

def __len__(self):
return self.end_index - self.start_index


def register_dataset(key: str, dataset: Dataset):
DatasetFactory().register(key, dataset)


@singleton
class DatasetFactory(object):

def __init__(self):
self.dataset_list = {
"iotdb.table": IoTDBTableModelDataset,
"iotdb.tree": IoTDBTreeModelDataset,
}

def register(self, key: str, dataset: Dataset):
if key not in self.dataset_list:
self.dataset_list[key] = dataset
else:
raise KeyError(f"Dataset {key} already exists")

def deregister(self, key: str):
del self.dataset_list[key]

def get_dataset(self, key: str):
if key not in self.dataset_list.keys():
raise KeyError(f"Dataset {key} does not exist")
return self.dataset_list[key]
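
A usage sketch of the relocated factory; `@singleton` means every `DatasetFactory()` call yields the same instance, and the built-in keys are the two defined above:

```python
# Usage sketch; `MyDataset` is a hypothetical stand-in.
from torch.utils.data import Dataset

from ainode.core.ingress.iotdb import DatasetFactory, register_dataset


class MyDataset(Dataset):
    def __len__(self):
        return 0


factory = DatasetFactory()                      # @singleton: one shared instance
table_cls = factory.get_dataset("iotdb.table")  # -> IoTDBTableModelDataset
register_dataset("my.source", MyDataset)        # raises KeyError if key exists
factory.deregister("my.source")                 # plain dict del
```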
@@ -27,6 +27,7 @@
import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId;
import org.apache.iotdb.common.rpc.thrift.TDataNodeConfiguration;
import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation;
import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.common.rpc.thrift.TFlushReq;
import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
@@ -41,7 +42,6 @@
import org.apache.iotdb.commons.auth.entity.PrivilegeUnion;
import org.apache.iotdb.commons.client.ainode.AINodeClient;
import org.apache.iotdb.commons.client.ainode.AINodeClientManager;
import org.apache.iotdb.commons.client.ainode.AINodeInfo;
import org.apache.iotdb.commons.cluster.NodeStatus;
import org.apache.iotdb.commons.cluster.NodeType;
import org.apache.iotdb.commons.conf.CommonConfig;
@@ -136,6 +136,7 @@
import org.apache.iotdb.confignode.persistence.schema.ClusterSchemaInfo;
import org.apache.iotdb.confignode.persistence.subscription.SubscriptionInfo;
import org.apache.iotdb.confignode.procedure.impl.schema.SchemaUtils;
import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo;
import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterReq;
import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartReq;
import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartResp;
@@ -2671,7 +2672,7 @@ private List<IDataSchema> fetchSchemaForTableModel(TCreateTrainingReq req) {
}
}
for (String tableName : dataSchemaForTable.getTableList()) {
dataSchemaList.add(new IDataSchema(dataSchemaForTable.curDatabase + DOT + tableName));
dataSchemaList.add(new IDataSchema(tableName));
}
return dataSchemaList;
}
@@ -2685,7 +2686,7 @@ public TSStatus createTraining(TCreateTrainingReq req) {

TTrainingReq trainingReq = new TTrainingReq();
trainingReq.setModelId(req.getModelId());
trainingReq.setModelType("timer_xl");
trainingReq.setModelType("sundial");
if (req.existingModelId != null) {
trainingReq.setExistingModelId(req.getExistingModelId());
}
@@ -2710,8 +2711,11 @@
updateModelInfo(new TUpdateModelInfoReq(req.modelId, ModelStatus.TRAINING.ordinal()));
trainingReq.setTargetDataSchema(dataSchema);

TAINodeInfo registeredAINode = getNodeManager().getRegisteredAINodeInfoList().get(0);
TEndPoint targetAINodeEndPoint =
new TEndPoint(registeredAINode.getInternalAddress(), registeredAINode.getInternalPort());
try (AINodeClient client =
AINodeClientManager.getInstance().borrowClient(AINodeInfo.endPoint)) {
AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint)) {
status = client.createTrainingTask(trainingReq);
if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
throw new IllegalArgumentException(status.message);
@@ -1377,8 +1377,8 @@ protected IConfigTask visitCreateTraining(CreateTraining node, MPPQueryContext c
node.isUseAllData(),
node.getTargetTimeRanges(),
node.getExistingModelId(),
node.getTargetDbs(),
tableList);
tableList,
node.getTargetDbs());
}

@Override
@@ -53,9 +53,6 @@ public CreateTrainingTask(
String existingModelId,
List<String> targetTables,
List<String> targetDbs) {
if (!modelType.equalsIgnoreCase("timer_xl")) {
throw new UnsupportedOperationException("Only TimerXL model is supported now.");
}
this.modelId = modelId;
this.modelType = modelType;
this.parameters = parameters;
@@ -76,9 +73,6 @@ public CreateTrainingTask(
List<List<Long>> timeRanges,
String existingModelId,
List<String> targetPaths) {
if (!modelType.equalsIgnoreCase("timer_xl")) {
throw new UnsupportedOperationException("Only TimerXL model is supported now.");
}
this.modelId = modelId;
this.modelType = modelType;
this.parameters = parameters;
@@ -199,7 +199,7 @@ public TForecastResp forecast(
TSStatus tsStatus = new TSStatus(CAN_NOT_CONNECT_AINODE.getStatusCode());
tsStatus.setMessage(
String.format(
"Failed to connect to AINode from DataNode when executing %s: %s",
"Failed to connect to AINode when executing %s: %s",
Thread.currentThread().getStackTrace()[1].getMethodName(), e.getMessage()));
return new TForecastResp(tsStatus, ByteBuffer.allocate(0));
}
Expand All @@ -210,7 +210,7 @@ public TSStatus createTrainingTask(TTrainingReq req) throws TException {
return client.createTrainingTask(req);
} catch (TException e) {
logger.warn(
"Failed to connect to AINode from DataNode when executing {}: {}",
"Failed to connect to AINode when executing {}: {}",
Thread.currentThread().getStackTrace()[1].getMethodName(),
e.getMessage());
throw new TException(MSG_CONNECTION_FAIL);

This file was deleted.
