2 changes: 2 additions & 0 deletions iotdb-core/ainode/ainode/core/ingress/dataset.py
@@ -33,6 +33,7 @@ def __init__(
         seq_len: int,
         input_token_len: int,
         output_token_len: int,
+        window_step: int,
     ):
         super().__init__(ip, port)
         # The number of the time series data points of the model input
@@ -42,3 +43,4 @@ def __init__(
         # The number of the time series data points of the model output
         self.output_token_len = output_token_len
         self.token_num = self.seq_len // self.input_token_len
+        self.window_step = window_step
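For orientation, a hedged sketch (illustrative only, not part of the patch) of what the new window_step attribute means for the base dataset: it is the stride between consecutive training windows, so window i starts at offset i * window_step and covers seq_len + output_token_len points.

# Illustrative stride semantics; the values below are made up, not defaults.
seq_len, output_token_len, window_step = 96, 96, 8
window_len = seq_len + output_token_len
starts = [i * window_step for i in range(4)]
print(starts)                                   # [0, 8, 16, 24]
print([(s, s + window_len) for s in starts])    # half-open [start, end) range of each window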
40 changes: 30 additions & 10 deletions iotdb-core/ainode/ainode/core/ingress/iotdb.py
@@ -59,6 +59,7 @@ def __init__(
         seq_len: int,
         input_token_len: int,
         output_token_len: int,
+        window_step: int,
         data_schema_list: list,
         ip: str = AINodeDescriptor().get_config().get_ain_cluster_ingress_address(),
         port: int = AINodeDescriptor().get_config().get_ain_cluster_ingress_port(),
@@ -74,7 +75,9 @@ def __init__(
         use_rate: float = 1.0,
         offset_rate: float = 0.0,
     ):
-        super().__init__(ip, port, seq_len, input_token_len, output_token_len)
+        super().__init__(
+            ip, port, seq_len, input_token_len, output_token_len, window_step
+        )

         self.SHOW_TIMESERIES = "show timeseries %s%s"
         self.COUNT_SERIES_SQL = "select count(%s) from %s%s"
@@ -139,7 +142,9 @@ def _fetch_schema(self, data_schema_list: list):
         window_sum = 0
         for seq_name, seq_value in sorted_series:
             # calculate and sum the number of training data windows for each time series
-            window_count = seq_value[1] - self.seq_len - self.output_token_len + 1
+            window_count = (
+                seq_value[1] - self.seq_len - self.output_token_len + 1
+            ) // self.window_step
             if window_count <= 1:
                 continue
             use_window_count = int(window_count * self.use_rate)
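A hedged sketch of the window-count arithmetic this hunk introduces (the helper name and numbers are made up): windows of seq_len + output_token_len points are taken every window_step points, and use_rate then keeps a fraction of them, mirroring use_window_count above.

def usable_windows(n_points, seq_len, output_token_len, window_step, use_rate=1.0):
    # Strided window count, as in the diff, then thinned by use_rate.
    window_count = (n_points - seq_len - output_token_len + 1) // window_step
    if window_count <= 1:
        return 0
    return int(window_count * use_rate)

print(usable_windows(10_000, 96, 96, 1))        # 9809 windows with step 1
print(usable_windows(10_000, 96, 96, 8))        # 1226 windows with step 8
print(usable_windows(10_000, 96, 96, 8, 0.5))   # 613 windows after keeping half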
@@ -176,14 +181,16 @@ def __getitem__(self, index):
             # try to get the training data window from cache first
             series_data = torch.tensor(series_data)
             result = series_data[
-                window_index : window_index + self.seq_len + self.output_token_len
+                window_index * self.window_step : window_index * self.window_step
+                + self.seq_len
+                + self.output_token_len
             ]
             return (
                 result[0 : self.seq_len],
                 result[self.input_token_len : self.seq_len + self.output_token_len],
                 np.ones(self.token_num, dtype=np.int32),
             )
-        result = []
+        series_data = []
         sql = ""
         try:
             if self.cache_enable:
@@ -204,12 +211,18 @@ def __getitem__(self, index):
                 )
                 with self.session.execute_query_statement(sql) as query_result:
                     while query_result.has_next():
-                        result.append(get_field_value(query_result.next().get_fields()[0]))
+                        series_data.append(
+                            get_field_value(query_result.next().get_fields()[0])
+                        )
         except Exception as e:
             logger.error("Executing sql: {} with exception: {}".format(sql, e))
         if self.cache_enable:
-            self.cache.put(cache_key, result)
-        result = torch.tensor(result)
+            self.cache.put(cache_key, series_data)
+        result = series_data[
+            window_index * self.window_step : window_index * self.window_step
+            + self.seq_len
+            + self.output_token_len
+        ]
         return (
             result[0 : self.seq_len],
             result[self.input_token_len : self.seq_len + self.output_token_len],
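A hedged sketch of the per-window sample that __getitem__ assembles under the new stride (illustrative helper; the real method also goes through the cache and the SQL fetch shown above): the slice starts at window_index * window_step, the input is its first seq_len points, the target is the same slice shifted by input_token_len, and the mask is all ones with one entry per input token.

import numpy as np
import torch

def make_sample(series_data, window_index, seq_len, input_token_len, output_token_len, window_step):
    # Illustrative only: strided slice, then the (input, target, mask) split used in the diff.
    start = window_index * window_step
    result = torch.tensor(series_data[start : start + seq_len + output_token_len])
    token_num = seq_len // input_token_len
    return (
        result[0:seq_len],                                     # model input
        result[input_token_len : seq_len + output_token_len],  # target, shifted by one input token
        np.ones(token_num, dtype=np.int32),                    # token mask
    )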
@@ -230,6 +243,7 @@ def __init__(
         seq_len: int,
         input_token_len: int,
         output_token_len: int,
+        window_step: int,
         data_schema_list: list,
         ip: str = AINodeDescriptor().get_config().get_ain_cluster_ingress_address(),
         port: int = AINodeDescriptor().get_config().get_ain_cluster_ingress_port(),
@@ -245,7 +259,9 @@ def __init__(
         use_rate: float = 1.0,
         offset_rate: float = 0.0,
     ):
-        super().__init__(ip, port, seq_len, input_token_len, output_token_len)
+        super().__init__(
+            ip, port, seq_len, input_token_len, output_token_len, window_step
+        )

         table_session_config = TableSessionConfig(
             node_urls=[f"{ip}:{port}"],
@@ -302,7 +318,9 @@ def _fetch_schema(self, data_schema_list: list):
         window_sum = 0
         for seq_name, seq_values in series_map.items():
             # calculate and sum the number of training data windows for each time series
-            window_count = len(seq_values) - self.seq_len - self.output_token_len + 1
+            window_count = (
+                len(seq_values) - self.seq_len - self.output_token_len + 1
+            ) // self.window_step
             if window_count <= 1:
                 continue
             use_window_count = int(window_count * self.use_rate)
@@ -331,7 +349,9 @@ def __getitem__(self, index):
         window_index -= self.series_with_prefix_sum[series_index - 1][2]
         window_index += self.series_with_prefix_sum[series_index][3]
         result = self.series_with_prefix_sum[series_index][4][
-            window_index : window_index + self.seq_len + self.output_token_len
+            window_index * self.window_step : window_index * self.window_step
+            + self.seq_len
+            + self.output_token_len
         ]
         result = torch.tensor(result)
         return (