diff --git a/iotdb-core/ainode/ainode/core/ingress/dataset.py b/iotdb-core/ainode/ainode/core/ingress/dataset.py
index 316c4235067f..4e3b5293c169 100644
--- a/iotdb-core/ainode/ainode/core/ingress/dataset.py
+++ b/iotdb-core/ainode/ainode/core/ingress/dataset.py
@@ -33,6 +33,7 @@ def __init__(
         seq_len: int,
         input_token_len: int,
         output_token_len: int,
+        window_step: int,
     ):
         super().__init__(ip, port)
         # The number of the time series data points of the model input
@@ -42,3 +43,4 @@ def __init__(
         # The number of the time series data points of the model output
         self.output_token_len = output_token_len
         self.token_num = self.seq_len // self.input_token_len
+        self.window_step = window_step
diff --git a/iotdb-core/ainode/ainode/core/ingress/iotdb.py b/iotdb-core/ainode/ainode/core/ingress/iotdb.py
index b9e844d91932..528c3cb73979 100644
--- a/iotdb-core/ainode/ainode/core/ingress/iotdb.py
+++ b/iotdb-core/ainode/ainode/core/ingress/iotdb.py
@@ -59,6 +59,7 @@ def __init__(
         seq_len: int,
         input_token_len: int,
         output_token_len: int,
+        window_step: int,
         data_schema_list: list,
         ip: str = AINodeDescriptor().get_config().get_ain_cluster_ingress_address(),
         port: int = AINodeDescriptor().get_config().get_ain_cluster_ingress_port(),
@@ -74,7 +75,9 @@ def __init__(
         use_rate: float = 1.0,
         offset_rate: float = 0.0,
     ):
-        super().__init__(ip, port, seq_len, input_token_len, output_token_len)
+        super().__init__(
+            ip, port, seq_len, input_token_len, output_token_len, window_step
+        )
 
         self.SHOW_TIMESERIES = "show timeseries %s%s"
         self.COUNT_SERIES_SQL = "select count(%s) from %s%s"
@@ -139,7 +142,9 @@ def _fetch_schema(self, data_schema_list: list):
         window_sum = 0
         for seq_name, seq_value in sorted_series:
             # calculate and sum the number of training data windows for each time series
-            window_count = seq_value[1] - self.seq_len - self.output_token_len + 1
+            window_count = (
+                seq_value[1] - self.seq_len - self.output_token_len + 1
+            ) // self.window_step
             if window_count <= 1:
                 continue
             use_window_count = int(window_count * self.use_rate)
@@ -176,14 +181,16 @@ def __getitem__(self, index):
             # try to get the training data window from cache first
             series_data = torch.tensor(series_data)
             result = series_data[
-                window_index : window_index + self.seq_len + self.output_token_len
+                window_index * self.window_step : window_index * self.window_step
+                + self.seq_len
+                + self.output_token_len
             ]
             return (
                 result[0 : self.seq_len],
                 result[self.input_token_len : self.seq_len + self.output_token_len],
                 np.ones(self.token_num, dtype=np.int32),
             )
-        result = []
+        series_data = []
         sql = ""
         try:
             if self.cache_enable:
@@ -204,12 +211,18 @@ def __getitem__(self, index):
             )
             with self.session.execute_query_statement(sql) as query_result:
                 while query_result.has_next():
-                    result.append(get_field_value(query_result.next().get_fields()[0]))
+                    series_data.append(
+                        get_field_value(query_result.next().get_fields()[0])
+                    )
         except Exception as e:
             logger.error("Executing sql: {} with exception: {}".format(sql, e))
         if self.cache_enable:
-            self.cache.put(cache_key, result)
-        result = torch.tensor(result)
+            self.cache.put(cache_key, series_data)
+        result = series_data[
+            window_index * self.window_step : window_index * self.window_step
+            + self.seq_len
+            + self.output_token_len
+        ]
         return (
             result[0 : self.seq_len],
             result[self.input_token_len : self.seq_len + self.output_token_len],
@@ -230,6 +243,7 @@ def __init__(
         seq_len: int,
         input_token_len: int,
         output_token_len: int,
+        window_step: int,
         data_schema_list: list,
         ip: str = AINodeDescriptor().get_config().get_ain_cluster_ingress_address(),
         port: int = AINodeDescriptor().get_config().get_ain_cluster_ingress_port(),
@@ -245,7 +259,9 @@ def __init__(
         use_rate: float = 1.0,
         offset_rate: float = 0.0,
     ):
-        super().__init__(ip, port, seq_len, input_token_len, output_token_len)
+        super().__init__(
+            ip, port, seq_len, input_token_len, output_token_len, window_step
+        )
 
         table_session_config = TableSessionConfig(
             node_urls=[f"{ip}:{port}"],
@@ -302,7 +318,9 @@ def _fetch_schema(self, data_schema_list: list):
         window_sum = 0
         for seq_name, seq_values in series_map.items():
             # calculate and sum the number of training data windows for each time series
-            window_count = len(seq_values) - self.seq_len - self.output_token_len + 1
+            window_count = (
+                len(seq_values) - self.seq_len - self.output_token_len + 1
+            ) // self.window_step
             if window_count <= 1:
                 continue
             use_window_count = int(window_count * self.use_rate)
@@ -331,7 +349,9 @@ def __getitem__(self, index):
             window_index -= self.series_with_prefix_sum[series_index - 1][2]
         window_index += self.series_with_prefix_sum[series_index][3]
         result = self.series_with_prefix_sum[series_index][4][
-            window_index : window_index + self.seq_len + self.output_token_len
+            window_index * self.window_step : window_index * self.window_step
+            + self.seq_len
+            + self.output_token_len
        ]
         result = torch.tensor(result)
         return (
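
Note on the windowing arithmetic in this patch: with a stride of window_step, window i now starts at offset i * window_step instead of i, and the per-series window count shrinks from (n - seq_len - output_token_len + 1) to (n - seq_len - output_token_len + 1) // window_step, so window_step = 1 reproduces the previous dense enumeration. Below is a minimal standalone sketch of that arithmetic; the names series and make_window are illustrative only, not from the codebase.

# Standalone sketch of the strided windowing this patch introduces.
# "series" and "make_window" are illustrative names, not from the codebase.
import torch

seq_len, input_token_len, output_token_len, window_step = 96, 16, 16, 4
series = torch.arange(1000, dtype=torch.float32)  # stand-in for one time series

# Per-series window count, integer-divided by the stride as in _fetch_schema:
window_count = (len(series) - seq_len - output_token_len + 1) // window_step

def make_window(window_index: int):
    # Window i starts at i * window_step, mirroring the slice in __getitem__:
    start = window_index * window_step
    window = series[start : start + seq_len + output_token_len]
    return (
        window[0:seq_len],  # model input
        window[input_token_len : seq_len + output_token_len],  # shifted target
    )

x, y = make_window(window_count - 1)  # the last valid window stays in bounds
# Input and target have equal length here because
# input_token_len == output_token_len in this illustrative configuration.
assert x.shape[0] == seq_len and y.shape[0] == seq_len

Larger strides trade training-sample count for less overlap between adjacent windows, which is why both _fetch_schema (window counting) and __getitem__ (window slicing) must apply the same window_step consistently.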