@@ -31,15 +31,15 @@ def __init__(self,
3131
3232 self .logger : Logger = logger
3333
34- def _prepare_and_set_ts_sets (self , all_ts_ids : np .ndarray , all_ts_row_ranges : np .ndarray , ts_id_name : str , random_state ) -> None :
34+ def _prepare_and_set_ts_sets (self , all_ts_ids : np .ndarray , all_ts_row_ranges : np .ndarray , ts_id_name : str , random_state : Optional [ int ], rd : np . random . RandomState ) -> None :
3535 """Validates and filters the input time series IDs based on the `dataset` and `source_type`. Handles random split."""
3636
3737 random_ts_ids = all_ts_ids [ts_id_name ]
3838 random_indices = np .arange (len (all_ts_ids ))
3939
4040 # Process train_ts if it was specified with time series ids
4141 if self .train_ts is not None and not isinstance (self .train_ts , (float , int )):
42- self .train_ts , self .train_ts_row_ranges , _ = SeriesBasedHandler ._process_ts_ids (self .train_ts , all_ts_ids , all_ts_row_ranges , None , None , self .logger , ts_id_name , random_state )
42+ self .train_ts , self .train_ts_row_ranges , _ = SeriesBasedHandler ._process_ts_ids (self .train_ts , all_ts_ids , all_ts_row_ranges , None , None , self .logger , ts_id_name , random_state , rd )
4343
4444 mask = np .isin (random_ts_ids , self .train_ts , invert = True )
4545 random_ts_ids = random_ts_ids [mask ]
@@ -49,7 +49,7 @@ def _prepare_and_set_ts_sets(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np
4949
5050 # Process val_ts if it was specified with time series ids
5151 if self .val_ts is not None and not isinstance (self .val_ts , (float , int )):
52- self .val_ts , self .val_ts_row_ranges , _ = SeriesBasedHandler ._process_ts_ids (self .val_ts , all_ts_ids , all_ts_row_ranges , None , None , self .logger , ts_id_name , random_state )
52+ self .val_ts , self .val_ts_row_ranges , _ = SeriesBasedHandler ._process_ts_ids (self .val_ts , all_ts_ids , all_ts_row_ranges , None , None , self .logger , ts_id_name , random_state , rd )
5353
5454 mask = np .isin (random_ts_ids , self .val_ts , invert = True )
5555 random_ts_ids = random_ts_ids [mask ]
@@ -59,7 +59,7 @@ def _prepare_and_set_ts_sets(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np
5959
6060 # Process test_ts if it was specified with time series ids
6161 if self .test_ts is not None and not isinstance (self .test_ts , (float , int )):
62- self .test_ts , self .test_ts_row_ranges , _ = SeriesBasedHandler ._process_ts_ids (self .test_ts , all_ts_ids , all_ts_row_ranges , None , None , self .logger , ts_id_name , random_state )
62+ self .test_ts , self .test_ts_row_ranges , _ = SeriesBasedHandler ._process_ts_ids (self .test_ts , all_ts_ids , all_ts_row_ranges , None , None , self .logger , ts_id_name , random_state , rd )
6363
6464 mask = np .isin (random_ts_ids , self .test_ts , invert = True )
6565 random_ts_ids = random_ts_ids [mask ]
@@ -80,23 +80,23 @@ def _prepare_and_set_ts_sets(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np
8080
8181 # Process random train_ts if it is to be randomly made
8282 if isinstance (self .train_ts , int ):
83- self .train_ts , self .train_ts_row_ranges , random_indices = SeriesBasedHandler ._process_ts_ids (None , all_ts_ids , all_ts_row_ranges , self .train_ts , random_indices , self .logger , ts_id_name , random_state )
83+ self .train_ts , self .train_ts_row_ranges , random_indices = SeriesBasedHandler ._process_ts_ids (None , all_ts_ids , all_ts_row_ranges , self .train_ts , random_indices , self .logger , ts_id_name , random_state , rd )
8484 self .logger .debug ("Random train_ts set with %s time series." , self .train_ts )
8585
8686 # Process random val_ts if it is to be randomly made
8787 if isinstance (self .val_ts , int ):
88- self .val_ts , self .val_ts_row_ranges , random_indices = SeriesBasedHandler ._process_ts_ids (None , all_ts_ids , all_ts_row_ranges , self .val_ts , random_indices , self .logger , ts_id_name , random_state )
88+ self .val_ts , self .val_ts_row_ranges , random_indices = SeriesBasedHandler ._process_ts_ids (None , all_ts_ids , all_ts_row_ranges , self .val_ts , random_indices , self .logger , ts_id_name , random_state , rd )
8989 self .logger .debug ("Random val_ts set with %s time series." , self .val_ts )
9090
9191 # Process random test_ts if it is to be randomly made
9292 if isinstance (self .test_ts , int ):
93- self .test_ts , self .test_ts_row_ranges , random_indices = SeriesBasedHandler ._process_ts_ids (None , all_ts_ids , all_ts_row_ranges , self .test_ts , random_indices , self .logger , ts_id_name , random_state )
93+ self .test_ts , self .test_ts_row_ranges , random_indices = SeriesBasedHandler ._process_ts_ids (None , all_ts_ids , all_ts_row_ranges , self .test_ts , random_indices , self .logger , ts_id_name , random_state , rd )
9494 self .logger .debug ("Random test_ts set with %s time series." , self .test_ts )
9595
9696 if self .uses_all_ts :
9797 if self .train_ts is None and self .val_ts is None and self .test_ts is None :
9898 self .all_ts = all_ts_ids [ts_id_name ]
99- self .all_ts , self .all_ts_row_ranges , _ = SeriesBasedHandler ._process_ts_ids (self .all_ts , all_ts_ids , all_ts_row_ranges , None , None , self .logger , ts_id_name , random_state )
99+ self .all_ts , self .all_ts_row_ranges , _ = SeriesBasedHandler ._process_ts_ids (self .all_ts , all_ts_ids , all_ts_row_ranges , None , None , self .logger , ts_id_name , random_state , rd )
100100 self .logger .info ("Using all time series for all_ts because train_ts, val_ts, and test_ts are all set to None." )
101101 else :
102102 for temp_ts_ids in [self .train_ts , self .val_ts , self .test_ts ]:
@@ -114,7 +114,7 @@ def _prepare_and_set_ts_sets(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np
114114 if self .test_ts is not None :
115115 self .logger .debug ("all_ts includes ids from test_ts." )
116116
117- self .all_ts , self .all_ts_row_ranges , _ = self ._process_ts_ids (self .all_ts , all_ts_ids , all_ts_row_ranges , None , None , self .logger , ts_id_name , random_state )
117+ self .all_ts , self .all_ts_row_ranges , _ = self ._process_ts_ids (self .all_ts , all_ts_ids , all_ts_row_ranges , None , None , self .logger , ts_id_name , random_state , rd )
118118 else :
119119 self .all_ts = None
120120
@@ -163,7 +163,7 @@ def _validate_ts_overlap(self):
163163 raise ValueError ("Train, Val, and Test can't have the same IDs." )
164164
165165 @staticmethod
166- def _process_ts_ids (ts_ids : np .ndarray , all_ts_ids : np .ndarray , all_ts_row_ranges : np .ndarray , split_size : float | int | None , random_indices : np .ndarray , logger : Logger , ts_id_name : str , random_state ) -> None :
166+ def _process_ts_ids (ts_ids : np .ndarray , all_ts_ids : np .ndarray , all_ts_row_ranges : np .ndarray , split_size : float | int | None , random_indices : np .ndarray , logger : Logger , ts_id_name : str , random_state : Optional [ int ], rd : np . random . RandomState ) -> None :
167167 """Validates and filters the input `ts_ids` based on the `dataset` and `source_type`. """
168168
169169 if ts_ids is None and split_size is None :
@@ -175,7 +175,7 @@ def _process_ts_ids(ts_ids: np.ndarray, all_ts_ids: np.ndarray, all_ts_row_range
175175 raise ValueError (f"Trying to use more time series than there are in the dataset. There are { len (all_ts_ids )} time series available." )
176176
177177 if split_size == len (random_indices ):
178- np . random .shuffle (random_indices )
178+ rd .shuffle (random_indices )
179179 ts_indices = random_indices
180180 ts_ids = all_ts_ids [ts_id_name ][ts_indices ]
181181 random_indices = np .array ([]) # No remaining indices
0 commit comments