@@ -59,6 +59,7 @@ def __init__(
5959 seq_len : int ,
6060 input_token_len : int ,
6161 output_token_len : int ,
62+ window_step : int ,
6263 data_schema_list : list ,
6364 ip : str = AINodeDescriptor ().get_config ().get_ain_cluster_ingress_address (),
6465 port : int = AINodeDescriptor ().get_config ().get_ain_cluster_ingress_port (),
@@ -74,7 +75,9 @@ def __init__(
7475 use_rate : float = 1.0 ,
7576 offset_rate : float = 0.0 ,
7677 ):
77- super ().__init__ (ip , port , seq_len , input_token_len , output_token_len )
78+ super ().__init__ (
79+ ip , port , seq_len , input_token_len , output_token_len , window_step
80+ )
7881
7982 self .SHOW_TIMESERIES = "show timeseries %s%s"
8083 self .COUNT_SERIES_SQL = "select count(%s) from %s%s"
@@ -139,7 +142,9 @@ def _fetch_schema(self, data_schema_list: list):
139142 window_sum = 0
140143 for seq_name , seq_value in sorted_series :
141144 # calculate and sum the number of training data windows for each time series
142- window_count = seq_value [1 ] - self .seq_len - self .output_token_len + 1
145+ window_count = (
146+ seq_value [1 ] - self .seq_len - self .output_token_len + 1
147+ ) // self .window_step
143148 if window_count <= 1 :
144149 continue
145150 use_window_count = int (window_count * self .use_rate )
@@ -176,14 +181,16 @@ def __getitem__(self, index):
176181 # try to get the training data window from cache first
177182 series_data = torch .tensor (series_data )
178183 result = series_data [
179- window_index : window_index + self .seq_len + self .output_token_len
184+ window_index * self .window_step : window_index * self .window_step
185+ + self .seq_len
186+ + self .output_token_len
180187 ]
181188 return (
182189 result [0 : self .seq_len ],
183190 result [self .input_token_len : self .seq_len + self .output_token_len ],
184191 np .ones (self .token_num , dtype = np .int32 ),
185192 )
186- result = []
193+ series_data = []
187194 sql = ""
188195 try :
189196 if self .cache_enable :
@@ -204,12 +211,18 @@ def __getitem__(self, index):
204211 )
205212 with self .session .execute_query_statement (sql ) as query_result :
206213 while query_result .has_next ():
207- result .append (get_field_value (query_result .next ().get_fields ()[0 ]))
214+ series_data .append (
215+ get_field_value (query_result .next ().get_fields ()[0 ])
216+ )
208217 except Exception as e :
209218 logger .error ("Executing sql: {} with exception: {}" .format (sql , e ))
210219 if self .cache_enable :
211- self .cache .put (cache_key , result )
212- result = torch .tensor (result )
220+ self .cache .put (cache_key , series_data )
221+ result = series_data [
222+ window_index * self .window_step : window_index * self .window_step
223+ + self .seq_len
224+ + self .output_token_len
225+ ]
213226 return (
214227 result [0 : self .seq_len ],
215228 result [self .input_token_len : self .seq_len + self .output_token_len ],
@@ -230,6 +243,7 @@ def __init__(
230243 seq_len : int ,
231244 input_token_len : int ,
232245 output_token_len : int ,
246+ window_step : int ,
233247 data_schema_list : list ,
234248 ip : str = AINodeDescriptor ().get_config ().get_ain_cluster_ingress_address (),
235249 port : int = AINodeDescriptor ().get_config ().get_ain_cluster_ingress_port (),
@@ -245,7 +259,9 @@ def __init__(
245259 use_rate : float = 1.0 ,
246260 offset_rate : float = 0.0 ,
247261 ):
248- super ().__init__ (ip , port , seq_len , input_token_len , output_token_len )
262+ super ().__init__ (
263+ ip , port , seq_len , input_token_len , output_token_len , window_step
264+ )
249265
250266 table_session_config = TableSessionConfig (
251267 node_urls = [f"{ ip } :{ port } " ],
@@ -302,7 +318,9 @@ def _fetch_schema(self, data_schema_list: list):
302318 window_sum = 0
303319 for seq_name , seq_values in series_map .items ():
304320 # calculate and sum the number of training data windows for each time series
305- window_count = len (seq_values ) - self .seq_len - self .output_token_len + 1
321+ window_count = (
322+ len (seq_values ) - self .seq_len - self .output_token_len + 1
323+ ) // self .window_step
306324 if window_count <= 1 :
307325 continue
308326 use_window_count = int (window_count * self .use_rate )
@@ -331,7 +349,9 @@ def __getitem__(self, index):
331349 window_index -= self .series_with_prefix_sum [series_index - 1 ][2 ]
332350 window_index += self .series_with_prefix_sum [series_index ][3 ]
333351 result = self .series_with_prefix_sum [series_index ][4 ][
334- window_index : window_index + self .seq_len + self .output_token_len
352+ window_index * self .window_step : window_index * self .window_step
353+ + self .seq_len
354+ + self .output_token_len
335355 ]
336356 result = torch .tensor (result )
337357 return (