diff --git a/iotdb-core/ainode/ainode/core/config.py b/iotdb-core/ainode/ainode/core/config.py
index 62de76fcbb839..edcdaf559be3c 100644
--- a/iotdb-core/ainode/ainode/core/config.py
+++ b/iotdb-core/ainode/ainode/core/config.py
@@ -52,6 +52,9 @@ def __init__(self):
         # log directory
         self._ain_logs_dir: str = AINODE_LOG_DIR
 
+        # cache size for ingress dataloader (MB); a value <= 0 disables caching
+        self._ain_data_cache_size = 50
+
         # Directory to save models
         self._ain_models_dir = AINODE_MODELS_DIR
 
@@ -94,6 +97,12 @@ def get_build_info(self) -> str:
     def set_build_info(self, build_info: str) -> None:
         self._build_info = build_info
 
+    def get_ain_data_cache_size(self) -> int:
+        return self._ain_data_cache_size
+
+    def set_ain_data_cache_size(self, ain_data_cache_size: int) -> None:
+        self._ain_data_cache_size = ain_data_cache_size
+
     def set_version_info(self, version_info: str) -> None:
         self._version_info = version_info
 
diff --git a/iotdb-core/ainode/ainode/core/ingress/__init__.py b/iotdb-core/ainode/ainode/core/ingress/__init__.py
new file mode 100644
index 0000000000000..2a1e720805f29
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/ingress/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
diff --git a/iotdb-core/ainode/ainode/core/ingress/dataset.py b/iotdb-core/ainode/ainode/core/ingress/dataset.py
new file mode 100644
index 0000000000000..c2410ed4374d9
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/ingress/dataset.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from torch.utils.data import Dataset
+
+from ainode.core.util.decorator import singleton
+
+
+class BasicDatabaseDataset(Dataset):
+    def __init__(self, ip: str, port: int):
+        self.ip = ip
+        self.port = port
+
+
+class BasicDatabaseForecastDataset(BasicDatabaseDataset):
+    def __init__(self, ip: str, port: int, input_len: int, output_len: int):
+        super().__init__(ip, port)
+        self.input_len = input_len
+        self.output_len = output_len
+
+
+def register_dataset(key: str, dataset: Dataset):
+    DatasetFactory().register(key, dataset)
+
+
+@singleton
+class DatasetFactory(object):
+
+    def __init__(self):
+        # Imported lazily: ainode.core.ingress.iotdb imports
+        # BasicDatabaseForecastDataset from this module, so a module-level
+        # import here would be circular.
+        from ainode.core.ingress.iotdb import (
+            IoTDBTableModelDataset,
+            IoTDBTreeModelDataset,
+        )
+
+        self.dataset_list = {
+            "iotdb.table": IoTDBTableModelDataset,
+            "iotdb.tree": IoTDBTreeModelDataset,
+        }
+
+    def register(self, key: str, dataset: Dataset):
+        if key in self.dataset_list:
+            raise KeyError(f"Dataset {key} already exists")
+        self.dataset_list[key] = dataset
+
+    def deregister(self, key: str):
+        del self.dataset_list[key]
+
+    def get_dataset(self, key: str):
+        if key not in self.dataset_list:
+            raise KeyError(f"Dataset {key} does not exist")
+        return self.dataset_list[key]
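+
+
+# Usage sketch (illustrative only; nothing in this patch executes it):
+#
+#   dataset_cls = DatasetFactory().get_dataset("iotdb.tree")
+#   dataset = dataset_cls(model_id="m1", input_len=96, out_len=24,
+#                         schema_list=[...])
+#
+# Extension sources can register under a new key with
+# register_dataset("my.source", MyDataset) before requesting them here.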
diff --git a/iotdb-core/ainode/ainode/core/ingress/iotdb.py b/iotdb-core/ainode/ainode/core/ingress/iotdb.py
new file mode 100644
index 0000000000000..4b034ac880842
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/ingress/iotdb.py
@@ -0,0 +1,322 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import torch
+from iotdb.Session import Session
+from iotdb.table_session import TableSession, TableSessionConfig
+from iotdb.utils.Field import Field
+from iotdb.utils.IoTDBConstants import TSDataType
+
+from ainode.core.config import AINodeDescriptor
+from ainode.core.ingress.dataset import BasicDatabaseForecastDataset
+from ainode.core.log import Logger
+from ainode.core.util.cache import MemoryLRUCache
+
+logger = Logger()
+
+
+def get_field_value(field: Field):
+    data_type = field.get_data_type()
+    if data_type == TSDataType.INT32:
+        return field.get_int_value()
+    elif data_type == TSDataType.INT64:
+        return field.get_long_value()
+    elif data_type == TSDataType.FLOAT:
+        return field.get_float_value()
+    elif data_type == TSDataType.DOUBLE:
+        return field.get_double_value()
+    else:
+        return field.get_string_value()
+
+
+def _cache_enable() -> bool:
+    return AINodeDescriptor().get_config().get_ain_data_cache_size() > 0
+
+
+class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
+    cache = MemoryLRUCache()
+
+    def __init__(
+        self,
+        model_id: str,
+        input_len: int,
+        out_len: int,
+        schema_list: list,
+        ip: str = "127.0.0.1",
+        port: int = 6667,
+        username: str = "root",
+        password: str = "root",
+        time_zone: str = "UTC+8",
+        start_split: float = 0,
+        end_split: float = 1,
+    ):
+        super().__init__(ip, port, input_len, out_len)
+        if end_split < start_split:
+            raise ValueError("end_split must be greater than start_split")
+
+        self.SHOW_TIMESERIES = "show timeseries %s"
+        self.COUNT_SERIES_SQL = "select count(%s) from %s%s"
+        self.FETCH_SERIES_SQL = "select %s from %s%s"
+        self.FETCH_SERIES_RANGE_SQL = "select %s from %s%s limit %s offset %s"
+
+        self.TIME_CONDITION = " where time>%s and time<%s"
+
+        self.session = Session.init_from_node_urls(
+            node_urls=[f"{ip}:{port}"],
+            user=username,
+            password=password,
+            zone_id=time_zone,
+        )
+        self.session.open(False)
+        self.context_length = self.input_len + self.output_len
+        self._fetch_schema(schema_list)
+        self.start_idx = int(self.total_count * start_split)
+        self.end_idx = int(self.total_count * end_split)
+        self.cache_enable = _cache_enable()
+        self.cache_key_prefix = model_id + "_"
+
+    def _fetch_schema(self, schema_list: list):
+        series_to_length = {}
+        for schema in schema_list:
+            path_pattern = schema.schemaName
+            series_list = []
+            time_condition = (
+                self.TIME_CONDITION % (schema.timeRange[0], schema.timeRange[1])
+                if schema.timeRange
+                else ""
+            )
+            # SHOW TIMESERIES only enumerates schema; the time filter is
+            # applied to the data queries below instead.
+            with self.session.execute_query_statement(
+                self.SHOW_TIMESERIES % path_pattern
+            ) as show_timeseries_result:
+                while show_timeseries_result.has_next():
+                    series_list.append(
+                        get_field_value(show_timeseries_result.next().get_fields()[0])
+                    )
+
+            for series in series_list:
+                split_series = series.split(".")
+                with self.session.execute_query_statement(
+                    self.COUNT_SERIES_SQL
+                    % (split_series[-1], ".".join(split_series[:-1]), time_condition)
+                ) as count_series_result:
+                    while count_series_result.has_next():
+                        length = get_field_value(
+                            count_series_result.next().get_fields()[0]
+                        )
+                        series_to_length[series] = (
+                            split_series,
+                            length,
+                            time_condition,
+                        )
+
+        # Sort series by length and keep a running (prefix) sum of window
+        # counts so __getitem__ can map a flat index to a (series, window).
+        sorted_series = sorted(series_to_length.items(), key=lambda x: x[1][1])
+        sorted_series_with_prefix_sum = []
+        window_sum = 0
+        for seq_name, seq_value in sorted_series:
+            window_count = seq_value[1] - self.context_length + 1
+            if window_count <= 0:
+                continue
+            window_sum += window_count
+            sorted_series_with_prefix_sum.append(
+                (seq_value[0], window_count, window_sum, seq_value[2])
+            )
+
+        self.total_count = window_sum
+        self.sorted_series = sorted_series_with_prefix_sum
+
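+    # Worked example of the prefix-sum lookup (assumed counts): with
+    # cumulative window counts [3, 7], flat index 2 stays in the first
+    # series (local window 2), while flat index 3 moves to the second
+    # series (local window 0): take the first entry whose cumulative count
+    # exceeds the index, then subtract the previous entry's cumulative sum.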
+    def __getitem__(self, index):
+        # Shift by the split start, then locate the owning series.
+        window_index = index + self.start_idx
+        series_index = 0
+        while self.sorted_series[series_index][2] <= window_index:
+            series_index += 1
+
+        if series_index != 0:
+            window_index -= self.sorted_series[series_index - 1][2]
+
+        series = self.sorted_series[series_index][0]
+        time_condition = self.sorted_series[series_index][3]
+        cache_key = self.cache_key_prefix + ".".join(series) + time_condition
+        if self.cache_enable:
+            series_data = self.cache.get(cache_key)
+            if series_data is not None:
+                series_data = torch.tensor(series_data)
+                result = series_data[window_index : window_index + self.context_length]
+                return result[0 : self.input_len].unsqueeze(-1), result[
+                    -self.output_len :
+                ].unsqueeze(-1)
+        result = []
+        try:
+            if self.cache_enable:
+                # Cache miss: fetch the full series once so that later
+                # windows are served from the cache.
+                sql = self.FETCH_SERIES_SQL % (
+                    series[-1],
+                    ".".join(series[0:-1]),
+                    time_condition,
+                )
+            else:
+                # No cache: fetch only the window that was asked for.
+                sql = self.FETCH_SERIES_RANGE_SQL % (
+                    series[-1],
+                    ".".join(series[0:-1]),
+                    time_condition,
+                    self.context_length,
+                    window_index,
+                )
+            with self.session.execute_query_statement(sql) as query_result:
+                while query_result.has_next():
+                    result.append(get_field_value(query_result.next().get_fields()[0]))
+        except Exception as e:
+            logger.error(f"Error happens when loading dataset: {e}")
+        if self.cache_enable:
+            self.cache.put(cache_key, result)
+            # The uncached fetch above returned the whole series; slice out
+            # the requested window before building the tensors.
+            result = result[window_index : window_index + self.context_length]
+        result = torch.tensor(result)
+        return result[0 : self.input_len].unsqueeze(-1), result[
+            -self.output_len :
+        ].unsqueeze(-1)
+
+    def __len__(self):
+        return self.end_idx - self.start_idx
+
+
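+# The table-model dataset below assumes tables shaped like
+# (time, item_id, value): one value column per row, with item_id naming the
+# series, which is what its SQL templates in __init__ rely on.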
+class IoTDBTableModelDataset(BasicDatabaseForecastDataset):
+
+    def __init__(
+        self,
+        input_len: int,
+        out_len: int,
+        data_schema_list: list,
+        ip: str = "127.0.0.1",
+        port: int = 6667,
+        username: str = "root",
+        password: str = "root",
+        time_zone: str = "UTC+8",
+        start_split: float = 0,
+        end_split: float = 1,
+    ):
+        super().__init__(ip, port, input_len, out_len)
+        if end_split < start_split:
+            raise ValueError("end_split must be greater than start_split")
+
+        # placeholders: fully qualified table name, then item_id
+        self.SELECT_SERIES_FORMAT_SQL = "select distinct item_id from %s"
+        self.COUNT_SERIES_LENGTH_SQL = (
+            "select count(value) from %s where item_id = '%s'"
+        )
+        self.FETCH_SERIES_SQL = (
+            "select value from %s where item_id = '%s' limit %s offset %s"
+        )
+        self.SERIES_NAME = "%s.%s"
+
+        table_session_config = TableSessionConfig(
+            node_urls=[f"{ip}:{port}"],
+            username=username,
+            password=password,
+            time_zone=time_zone,
+        )
+
+        self.session = TableSession(table_session_config)
+        self.context_length = self.input_len + self.output_len
+        self._fetch_schema(data_schema_list)
+
+        self.start_index = int(self.total_count * start_split)
+        self.end_index = int(self.total_count * end_split)
+
+    def _fetch_schema(self, data_schema_list: list):
+        series_to_length = {}
+        for data_schema in data_schema_list:
+            series_list = []
+            with self.session.execute_query_statement(
+                self.SELECT_SERIES_FORMAT_SQL % data_schema.schemaName
+            ) as show_devices_result:
+                while show_devices_result.has_next():
+                    series_list.append(
+                        get_field_value(show_devices_result.next().get_fields()[0])
+                    )
+
+            for series in series_list:
+                with self.session.execute_query_statement(
+                    self.COUNT_SERIES_LENGTH_SQL % (data_schema.schemaName, series)
+                ) as count_series_result:
+                    length = get_field_value(count_series_result.next().get_fields()[0])
+                    series_to_length[
+                        self.SERIES_NAME % (data_schema.schemaName, series)
+                    ] = length
+
+        sorted_series = sorted(series_to_length.items(), key=lambda x: x[1])
+        sorted_series_with_prefix_sum = []
+        window_sum = 0
+        for seq_name, seq_length in sorted_series:
+            window_count = seq_length - self.context_length + 1
+            if window_count <= 0:
+                continue
+            window_sum += window_count
+            sorted_series_with_prefix_sum.append((seq_name, window_count, window_sum))
+
+        self.total_count = window_sum
+        self.sorted_series = sorted_series_with_prefix_sum
+
+    def __getitem__(self, index):
+        # Same flat-index to (series, window) mapping as the tree dataset.
+        window_index = index + self.start_index
+        series_index = 0
+        while self.sorted_series[series_index][2] <= window_index:
+            series_index += 1
+
+        if series_index != 0:
+            window_index -= self.sorted_series[series_index - 1][2]
+
+        series = self.sorted_series[series_index][0]
+        # Keys look like "<database>.<table>.<item_id>"; the last segment
+        # is the item id, the rest is the table reference.
+        parts = series.split(".")
+        table_name = ".".join(parts[:-1])
+        item_id = parts[-1]
+
+        result = []
+        try:
+            with self.session.execute_query_statement(
+                self.FETCH_SERIES_SQL
+                % (table_name, item_id, self.context_length, window_index)
+            ) as query_result:
+                while query_result.has_next():
+                    result.append(get_field_value(query_result.next().get_fields()[0]))
+        except Exception as e:
+            logger.error(f"Error happens when loading dataset: {e}")
+        result = torch.tensor(result)
+        return result[0 : self.input_len].unsqueeze(-1), result[
+            -self.output_len :
+        ].unsqueeze(-1)
+
+    def __len__(self):
+        return self.end_index - self.start_index
diff --git a/iotdb-core/ainode/ainode/core/util/cache.py b/iotdb-core/ainode/ainode/core/util/cache.py
new file mode 100644
index 0000000000000..e0cfd5883167a
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/util/cache.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import sys
+from collections import OrderedDict
+
+from ainode.core.config import AINodeDescriptor
+from ainode.core.util.decorator import singleton
+
+
+def _estimate_size_in_byte(obj):
+    # Rough CPython-flavored estimates; exact accounting is not needed here.
+    if isinstance(obj, str):
+        return len(obj) + 49
+    elif isinstance(obj, int):
+        return 28
+    elif isinstance(obj, list):
+        return 64 + sum(_estimate_size_in_byte(x) for x in obj)
+    elif isinstance(obj, dict):
+        return 280 + sum(
+            _estimate_size_in_byte(k) + _estimate_size_in_byte(v)
+            for k, v in obj.items()
+        )
+    else:
+        return sys.getsizeof(obj)
+
+
+def _get_item_memory(key, value) -> int:
+    return _estimate_size_in_byte(key) + _estimate_size_in_byte(value)
+
+
+@singleton
+class MemoryLRUCache:
+    def __init__(self):
+        self.cache = OrderedDict()
+        self.max_memory_bytes = (
+            AINodeDescriptor().get_config().get_ain_data_cache_size() * 1024 * 1024
+        )
+        self.current_memory = 0
+
+    def get(self, key):
+        if key not in self.cache:
+            return None
+        value = self.cache[key]
+        self.cache.move_to_end(key)
+        return value
+
+    def put(self, key, value):
+        item_memory = _get_item_memory(key, value)
+
+        if key in self.cache:
+            # Replace: drop the old entry's footprint before adding the new.
+            old_value = self.cache[key]
+            self.current_memory -= _get_item_memory(key, old_value)
+            self.current_memory += item_memory
+            self._evict_if_needed()
+            self.cache[key] = value
+            self.cache.move_to_end(key)
+        else:
+            self.current_memory += item_memory
+            self._evict_if_needed()
+            self.cache[key] = value
+
+    def _evict_if_needed(self):
+        # Evict from the least recently used end until within budget.
+        while self.current_memory > self.max_memory_bytes and self.cache:
+            key, value = self.cache.popitem(last=False)
+            self.current_memory -= _get_item_memory(key, value)
+
+    def get_current_memory_mb(self) -> float:
+        return self.current_memory / (1024 * 1024)
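+
+
+# Rough sizing example (illustrative): under _estimate_size_in_byte, a list
+# of one million ints costs about 64 + 1_000_000 * 28 bytes, roughly 28 MB,
+# so the default 50 MB budget holds one such cached series; a second put()
+# pushes current_memory past the budget and evicts the least recent entry.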