Skip to content

Commit 03e7018

Browse files
authored
[AINode] Add window_step options for dataset (apache#15857)
1 parent d86b86e commit 03e7018

File tree

2 files changed: +32 additions, −10 deletions

iotdb-core/ainode/ainode/core/ingress/dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(
3333
seq_len: int,
3434
input_token_len: int,
3535
output_token_len: int,
36+
window_step: int,
3637
):
3738
super().__init__(ip, port)
3839
# The number of the time series data points of the model input
@@ -42,3 +43,4 @@ def __init__(
4243
# The number of the time series data points of the model output
4344
self.output_token_len = output_token_len
4445
self.token_num = self.seq_len // self.input_token_len
46+
self.window_step = window_step

iotdb-core/ainode/ainode/core/ingress/iotdb.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(
5959
seq_len: int,
6060
input_token_len: int,
6161
output_token_len: int,
62+
window_step: int,
6263
data_schema_list: list,
6364
ip: str = AINodeDescriptor().get_config().get_ain_cluster_ingress_address(),
6465
port: int = AINodeDescriptor().get_config().get_ain_cluster_ingress_port(),
@@ -74,7 +75,9 @@ def __init__(
7475
use_rate: float = 1.0,
7576
offset_rate: float = 0.0,
7677
):
77-
super().__init__(ip, port, seq_len, input_token_len, output_token_len)
78+
super().__init__(
79+
ip, port, seq_len, input_token_len, output_token_len, window_step
80+
)
7881

7982
self.SHOW_TIMESERIES = "show timeseries %s%s"
8083
self.COUNT_SERIES_SQL = "select count(%s) from %s%s"
@@ -139,7 +142,9 @@ def _fetch_schema(self, data_schema_list: list):
139142
window_sum = 0
140143
for seq_name, seq_value in sorted_series:
141144
# calculate and sum the number of training data windows for each time series
142-
window_count = seq_value[1] - self.seq_len - self.output_token_len + 1
145+
window_count = (
146+
seq_value[1] - self.seq_len - self.output_token_len + 1
147+
) // self.window_step
143148
if window_count <= 1:
144149
continue
145150
use_window_count = int(window_count * self.use_rate)
@@ -176,14 +181,16 @@ def __getitem__(self, index):
176181
# try to get the training data window from cache first
177182
series_data = torch.tensor(series_data)
178183
result = series_data[
179-
window_index : window_index + self.seq_len + self.output_token_len
184+
window_index * self.window_step : window_index * self.window_step
185+
+ self.seq_len
186+
+ self.output_token_len
180187
]
181188
return (
182189
result[0 : self.seq_len],
183190
result[self.input_token_len : self.seq_len + self.output_token_len],
184191
np.ones(self.token_num, dtype=np.int32),
185192
)
186-
result = []
193+
series_data = []
187194
sql = ""
188195
try:
189196
if self.cache_enable:
@@ -204,12 +211,18 @@ def __getitem__(self, index):
204211
)
205212
with self.session.execute_query_statement(sql) as query_result:
206213
while query_result.has_next():
207-
result.append(get_field_value(query_result.next().get_fields()[0]))
214+
series_data.append(
215+
get_field_value(query_result.next().get_fields()[0])
216+
)
208217
except Exception as e:
209218
logger.error("Executing sql: {} with exception: {}".format(sql, e))
210219
if self.cache_enable:
211-
self.cache.put(cache_key, result)
212-
result = torch.tensor(result)
220+
self.cache.put(cache_key, series_data)
221+
result = series_data[
222+
window_index * self.window_step : window_index * self.window_step
223+
+ self.seq_len
224+
+ self.output_token_len
225+
]
213226
return (
214227
result[0 : self.seq_len],
215228
result[self.input_token_len : self.seq_len + self.output_token_len],
@@ -230,6 +243,7 @@ def __init__(
230243
seq_len: int,
231244
input_token_len: int,
232245
output_token_len: int,
246+
window_step: int,
233247
data_schema_list: list,
234248
ip: str = AINodeDescriptor().get_config().get_ain_cluster_ingress_address(),
235249
port: int = AINodeDescriptor().get_config().get_ain_cluster_ingress_port(),
@@ -245,7 +259,9 @@ def __init__(
245259
use_rate: float = 1.0,
246260
offset_rate: float = 0.0,
247261
):
248-
super().__init__(ip, port, seq_len, input_token_len, output_token_len)
262+
super().__init__(
263+
ip, port, seq_len, input_token_len, output_token_len, window_step
264+
)
249265

250266
table_session_config = TableSessionConfig(
251267
node_urls=[f"{ip}:{port}"],
@@ -302,7 +318,9 @@ def _fetch_schema(self, data_schema_list: list):
302318
window_sum = 0
303319
for seq_name, seq_values in series_map.items():
304320
# calculate and sum the number of training data windows for each time series
305-
window_count = len(seq_values) - self.seq_len - self.output_token_len + 1
321+
window_count = (
322+
len(seq_values) - self.seq_len - self.output_token_len + 1
323+
) // self.window_step
306324
if window_count <= 1:
307325
continue
308326
use_window_count = int(window_count * self.use_rate)
@@ -331,7 +349,9 @@ def __getitem__(self, index):
331349
window_index -= self.series_with_prefix_sum[series_index - 1][2]
332350
window_index += self.series_with_prefix_sum[series_index][3]
333351
result = self.series_with_prefix_sum[series_index][4][
334-
window_index : window_index + self.seq_len + self.output_token_len
352+
window_index * self.window_step : window_index * self.window_step
353+
+ self.seq_len
354+
+ self.output_token_len
335355
]
336356
result = torch.tensor(result)
337357
return (

0 commit comments

Comments
 (0)