Skip to content

Commit 4ae5201

Browse files
committed
Fix: Setting random_state no longer affects the global NumPy random state
1 parent c3e2745 commit 4ae5201

File tree

6 files changed

+26
-23
lines changed

6 files changed

+26
-23
lines changed

cesnet_tszoo/configs/base_config.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,6 @@ def __init__(self,
185185
self.train_dataloader_order: DataloaderOrder = train_dataloader_order
186186
self.random_state: Optional[int] = random_state
187187

188-
if self.random_state is not None:
189-
np.random.seed(random_state)
190-
191188
self._validate_construction()
192189

193190
self.logger.info("Quick validation succeeded.")
@@ -364,12 +361,14 @@ def _update_identifiers_from_dataset_metadata(self, dataset_metadata: DatasetMet
364361
def _dataset_init(self, dataset_metadata: DatasetMetadata) -> None:
365362
"""Performs deeper parameter validation and updates values based on data from the dataset. """
366363

364+
rd = np.random.RandomState(self.random_state)
365+
367366
self.ts_id_name = dataset_metadata.ts_id_name
368367

369368
self._set_features_to_take(dataset_metadata.features)
370369
self.logger.debug("Features to take have been successfully set.")
371370

372-
self._set_ts(dataset_metadata.ts_indices, dataset_metadata.ts_row_ranges)
371+
self._set_ts(dataset_metadata.ts_indices, dataset_metadata.ts_row_ranges, rd)
373372
self.logger.debug("Time series IDs have been successfully set.")
374373

375374
self._set_time_period(dataset_metadata.time_indices)
@@ -661,7 +660,7 @@ def _set_time_period(self, all_time_ids: np.ndarray) -> None:
661660
...
662661

663662
@abstractmethod
664-
def _set_ts(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np.ndarray) -> None:
663+
def _set_ts(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np.ndarray, rd: np.random.RandomState) -> None:
665664
"""Validates and filters the input time series IDs based on the `dataset` and `source_type`. This typically calls [`_process_ts_ids`](reference_dataset_config.md#references.DatasetConfig._process_ts_ids) for each time series ID filter. """
666665
...
667666

cesnet_tszoo/configs/disjoint_time_based_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,10 +288,10 @@ def _set_time_period(self, all_time_ids: np.ndarray) -> None:
288288

289289
self._prepare_and_set_time_period_sets(all_time_ids, self.time_format)
290290

291-
def _set_ts(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np.ndarray) -> None:
291+
def _set_ts(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np.ndarray, rd: np.random.RandomState) -> None:
292292
""" Validates and filters inputted time series id from `train_ts`, `val_ts` and `test_ts` based on `dataset` and `source_type`. Handles random set."""
293293

294-
self._prepare_and_set_ts_sets(all_ts_ids, all_ts_row_ranges, self.ts_id_name, self.random_state)
294+
self._prepare_and_set_ts_sets(all_ts_ids, all_ts_row_ranges, self.ts_id_name, self.random_state, rd)
295295

296296
def _get_feature_transformers(self) -> Transformer:
297297
"""Creates transformer with `transformer_factory`. """

cesnet_tszoo/configs/handlers/series_based_handler.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ def __init__(self,
3131

3232
self.logger: Logger = logger
3333

34-
def _prepare_and_set_ts_sets(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np.ndarray, ts_id_name: str, random_state) -> None:
34+
def _prepare_and_set_ts_sets(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np.ndarray, ts_id_name: str, random_state: Optional[int], rd: np.random.RandomState) -> None:
3535
"""Validates and filters the input time series IDs based on the `dataset` and `source_type`. Handles random split."""
3636

3737
random_ts_ids = all_ts_ids[ts_id_name]
3838
random_indices = np.arange(len(all_ts_ids))
3939

4040
# Process train_ts if it was specified with time series ids
4141
if self.train_ts is not None and not isinstance(self.train_ts, (float, int)):
42-
self.train_ts, self.train_ts_row_ranges, _ = SeriesBasedHandler._process_ts_ids(self.train_ts, all_ts_ids, all_ts_row_ranges, None, None, self.logger, ts_id_name, random_state)
42+
self.train_ts, self.train_ts_row_ranges, _ = SeriesBasedHandler._process_ts_ids(self.train_ts, all_ts_ids, all_ts_row_ranges, None, None, self.logger, ts_id_name, random_state, rd)
4343

4444
mask = np.isin(random_ts_ids, self.train_ts, invert=True)
4545
random_ts_ids = random_ts_ids[mask]
@@ -49,7 +49,7 @@ def _prepare_and_set_ts_sets(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np
4949

5050
# Process val_ts if it was specified with time series ids
5151
if self.val_ts is not None and not isinstance(self.val_ts, (float, int)):
52-
self.val_ts, self.val_ts_row_ranges, _ = SeriesBasedHandler._process_ts_ids(self.val_ts, all_ts_ids, all_ts_row_ranges, None, None, self.logger, ts_id_name, random_state)
52+
self.val_ts, self.val_ts_row_ranges, _ = SeriesBasedHandler._process_ts_ids(self.val_ts, all_ts_ids, all_ts_row_ranges, None, None, self.logger, ts_id_name, random_state, rd)
5353

5454
mask = np.isin(random_ts_ids, self.val_ts, invert=True)
5555
random_ts_ids = random_ts_ids[mask]
@@ -59,7 +59,7 @@ def _prepare_and_set_ts_sets(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np
5959

6060
# Process test_ts if it was specified with time series ids
6161
if self.test_ts is not None and not isinstance(self.test_ts, (float, int)):
62-
self.test_ts, self.test_ts_row_ranges, _ = SeriesBasedHandler._process_ts_ids(self.test_ts, all_ts_ids, all_ts_row_ranges, None, None, self.logger, ts_id_name, random_state)
62+
self.test_ts, self.test_ts_row_ranges, _ = SeriesBasedHandler._process_ts_ids(self.test_ts, all_ts_ids, all_ts_row_ranges, None, None, self.logger, ts_id_name, random_state, rd)
6363

6464
mask = np.isin(random_ts_ids, self.test_ts, invert=True)
6565
random_ts_ids = random_ts_ids[mask]
@@ -80,23 +80,23 @@ def _prepare_and_set_ts_sets(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np
8080

8181
# Process random train_ts if it is to be randomly made
8282
if isinstance(self.train_ts, int):
83-
self.train_ts, self.train_ts_row_ranges, random_indices = SeriesBasedHandler._process_ts_ids(None, all_ts_ids, all_ts_row_ranges, self.train_ts, random_indices, self.logger, ts_id_name, random_state)
83+
self.train_ts, self.train_ts_row_ranges, random_indices = SeriesBasedHandler._process_ts_ids(None, all_ts_ids, all_ts_row_ranges, self.train_ts, random_indices, self.logger, ts_id_name, random_state, rd)
8484
self.logger.debug("Random train_ts set with %s time series.", self.train_ts)
8585

8686
# Process random val_ts if it is to be randomly made
8787
if isinstance(self.val_ts, int):
88-
self.val_ts, self.val_ts_row_ranges, random_indices = SeriesBasedHandler._process_ts_ids(None, all_ts_ids, all_ts_row_ranges, self.val_ts, random_indices, self.logger, ts_id_name, random_state)
88+
self.val_ts, self.val_ts_row_ranges, random_indices = SeriesBasedHandler._process_ts_ids(None, all_ts_ids, all_ts_row_ranges, self.val_ts, random_indices, self.logger, ts_id_name, random_state, rd)
8989
self.logger.debug("Random val_ts set with %s time series.", self.val_ts)
9090

9191
# Process random test_ts if it is to be randomly made
9292
if isinstance(self.test_ts, int):
93-
self.test_ts, self.test_ts_row_ranges, random_indices = SeriesBasedHandler._process_ts_ids(None, all_ts_ids, all_ts_row_ranges, self.test_ts, random_indices, self.logger, ts_id_name, random_state)
93+
self.test_ts, self.test_ts_row_ranges, random_indices = SeriesBasedHandler._process_ts_ids(None, all_ts_ids, all_ts_row_ranges, self.test_ts, random_indices, self.logger, ts_id_name, random_state, rd)
9494
self.logger.debug("Random test_ts set with %s time series.", self.test_ts)
9595

9696
if self.uses_all_ts:
9797
if self.train_ts is None and self.val_ts is None and self.test_ts is None:
9898
self.all_ts = all_ts_ids[ts_id_name]
99-
self.all_ts, self.all_ts_row_ranges, _ = SeriesBasedHandler._process_ts_ids(self.all_ts, all_ts_ids, all_ts_row_ranges, None, None, self.logger, ts_id_name, random_state)
99+
self.all_ts, self.all_ts_row_ranges, _ = SeriesBasedHandler._process_ts_ids(self.all_ts, all_ts_ids, all_ts_row_ranges, None, None, self.logger, ts_id_name, random_state, rd)
100100
self.logger.info("Using all time series for all_ts because train_ts, val_ts, and test_ts are all set to None.")
101101
else:
102102
for temp_ts_ids in [self.train_ts, self.val_ts, self.test_ts]:
@@ -114,7 +114,7 @@ def _prepare_and_set_ts_sets(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np
114114
if self.test_ts is not None:
115115
self.logger.debug("all_ts includes ids from test_ts.")
116116

117-
self.all_ts, self.all_ts_row_ranges, _ = self._process_ts_ids(self.all_ts, all_ts_ids, all_ts_row_ranges, None, None, self.logger, ts_id_name, random_state)
117+
self.all_ts, self.all_ts_row_ranges, _ = self._process_ts_ids(self.all_ts, all_ts_ids, all_ts_row_ranges, None, None, self.logger, ts_id_name, random_state, rd)
118118
else:
119119
self.all_ts = None
120120

@@ -163,7 +163,7 @@ def _validate_ts_overlap(self):
163163
raise ValueError("Train, Val, and Test can't have the same IDs.")
164164

165165
@staticmethod
166-
def _process_ts_ids(ts_ids: np.ndarray, all_ts_ids: np.ndarray, all_ts_row_ranges: np.ndarray, split_size: float | int | None, random_indices: np.ndarray, logger: Logger, ts_id_name: str, random_state) -> None:
166+
def _process_ts_ids(ts_ids: np.ndarray, all_ts_ids: np.ndarray, all_ts_row_ranges: np.ndarray, split_size: float | int | None, random_indices: np.ndarray, logger: Logger, ts_id_name: str, random_state: Optional[int], rd: np.random.RandomState) -> None:
167167
"""Validates and filters the input `ts_ids` based on the `dataset` and `source_type`. """
168168

169169
if ts_ids is None and split_size is None:
@@ -175,7 +175,7 @@ def _process_ts_ids(ts_ids: np.ndarray, all_ts_ids: np.ndarray, all_ts_row_range
175175
raise ValueError(f"Trying to use more time series than there are in the dataset. There are {len(all_ts_ids)} time series available.")
176176

177177
if split_size == len(random_indices):
178-
np.random.shuffle(random_indices)
178+
rd.shuffle(random_indices)
179179
ts_indices = random_indices
180180
ts_ids = all_ts_ids[ts_id_name][ts_indices]
181181
random_indices = np.array([]) # No remaining indices

cesnet_tszoo/configs/series_based_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,10 +236,10 @@ def _set_time_period(self, all_time_ids: np.ndarray) -> None:
236236
self.time_period, self.display_time_period = TimeBasedHandler._process_time_period(self.time_period, all_time_ids, self.logger, self.time_format)
237237
self.logger.debug("Processed time_period: %s, display_time_period: %s", self.time_period, self.display_time_period)
238238

239-
def _set_ts(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np.ndarray) -> None:
239+
def _set_ts(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np.ndarray, rd: np.random.RandomState) -> None:
240240
"""Validates and filters the input time series IDs based on the `dataset` and `source_type`. Handles random split."""
241241

242-
self._prepare_and_set_ts_sets(all_ts_ids, all_ts_row_ranges, self.ts_id_name, self.random_state)
242+
self._prepare_and_set_ts_sets(all_ts_ids, all_ts_row_ranges, self.ts_id_name, self.random_state, rd)
243243

244244
def _get_feature_transformers(self) -> Transformer:
245245
"""Creates transformer with `transformer_factory`. """

cesnet_tszoo/configs/time_based_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,15 +283,15 @@ def _set_time_period(self, all_time_ids: np.ndarray) -> None:
283283

284284
self._prepare_and_set_time_period_sets(all_time_ids, self.time_format)
285285

286-
def _set_ts(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np.ndarray) -> None:
286+
def _set_ts(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np.ndarray, rd: np.random.RandomState) -> None:
287287
""" Validates and filters inputted time series id from `ts_ids` based on `dataset` and `source_type`. Handles random set."""
288288

289289
random_ts_ids = all_ts_ids[self.ts_id_name]
290290
random_indices = np.arange(len(all_ts_ids))
291291

292292
# Process ts_ids if it was specified with time series ids
293293
if not isinstance(self.ts_ids, (float, int)):
294-
self.ts_ids, self.ts_row_ranges, _ = SeriesBasedHandler._process_ts_ids(self.ts_ids, all_ts_ids, all_ts_row_ranges, None, None, self.logger, self.ts_id_name, self.random_state)
294+
self.ts_ids, self.ts_row_ranges, _ = SeriesBasedHandler._process_ts_ids(self.ts_ids, all_ts_ids, all_ts_row_ranges, None, None, self.logger, self.ts_id_name, self.random_state, rd)
295295

296296
mask = np.isin(random_ts_ids, self.ts_ids, invert=True)
297297
random_ts_ids = random_ts_ids[mask]
@@ -306,7 +306,7 @@ def _set_ts(self, all_ts_ids: np.ndarray, all_ts_row_ranges: np.ndarray) -> None
306306

307307
# Process random ts_ids if it is to be randomly made
308308
if isinstance(self.ts_ids, int):
309-
self.ts_ids, self.ts_row_ranges, random_indices = SeriesBasedHandler._process_ts_ids(None, all_ts_ids, all_ts_row_ranges, self.ts_ids, random_indices, self.logger, self.ts_id_name, self.random_state)
309+
self.ts_ids, self.ts_row_ranges, random_indices = SeriesBasedHandler._process_ts_ids(None, all_ts_ids, all_ts_row_ranges, self.ts_ids, random_indices, self.logger, self.ts_id_name, self.random_state, rd)
310310
self.logger.debug("Random ts_ids set with %s time series.", self.ts_ids)
311311

312312
def _get_feature_transformers(self) -> np.ndarray[Transformer] | Transformer:

cesnet_tszoo/datasets/cesnet_dataset.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ def set_dataset_config_and_initialize(self, dataset_config: DatasetConfig, displ
110110
display_config_details: Flag indicating whether and how to display the configuration values after initialization. `Default: text`
111111
workers: The number of workers to use during initialization. `Default: "config"`
112112
"""
113+
114+
if self.dataset_config is not None and self.dataset_config != dataset_config:
115+
raise ValueError("This dataset is already initialized with config. Create new dataset to configure with passed dataset_config!")
116+
113117
if display_config_details is not None:
114118
display_config_details = DisplayType(display_config_details)
115119

0 commit comments

Comments
 (0)