Skip to content

Commit 3f64add

Browse files
committed
Fix: Fixed invalid transformer training on all sets
1 parent 045587b commit 3f64add

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

cesnet_tszoo/datasets/disjoint_time_based_cesnet_dataset.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -434,8 +434,9 @@ def _initialize_transformers_and_details(self, workers: int) -> None:
434434
all_ts_ids_to_take = np.array([])
435435

436436
if self.dataset_config.has_train():
437+
can_fit_transformers = self.dataset_config.transform_with is not None and (not self.dataset_config.are_transformers_premade or self.dataset_config.partial_fit_initialized_transformers)
437438
updated_ts_row_ranges, updated_ts_ids, updated_fillers, updated_anomaly_handlers = self.__initialize_transformers_and_details_for_set(self.dataset_config.train_ts, self.dataset_config.train_ts_row_ranges, self.dataset_config.train_time_period,
438-
self.dataset_config.train_fillers, self.dataset_config.anomaly_handlers, workers, "train")
439+
self.dataset_config.train_fillers, self.dataset_config.anomaly_handlers, workers, "train", can_fit_transformers)
439440
self.dataset_config.train_ts = updated_ts_ids
440441
self.dataset_config.train_ts_row_ranges = updated_ts_row_ranges
441442
self.dataset_config.train_fillers = updated_fillers
@@ -447,7 +448,7 @@ def _initialize_transformers_and_details(self, workers: int) -> None:
447448

448449
if self.dataset_config.has_val():
449450
updated_ts_row_ranges, updated_ts_ids, updated_fillers, _ = self.__initialize_transformers_and_details_for_set(self.dataset_config.val_ts, self.dataset_config.val_ts_row_ranges, self.dataset_config.val_time_period,
450-
self.dataset_config.val_fillers, None, workers, "val")
451+
self.dataset_config.val_fillers, None, workers, "val", False)
451452
self.dataset_config.val_ts = updated_ts_ids
452453
self.dataset_config.val_ts_row_ranges = updated_ts_row_ranges
453454
self.dataset_config.val_fillers = updated_fillers
@@ -458,7 +459,7 @@ def _initialize_transformers_and_details(self, workers: int) -> None:
458459

459460
if self.dataset_config.has_test():
460461
updated_ts_row_ranges, updated_ts_ids, updated_fillers, _ = self.__initialize_transformers_and_details_for_set(self.dataset_config.test_ts, self.dataset_config.test_ts_row_ranges, self.dataset_config.test_time_period,
461-
self.dataset_config.test_fillers, None, workers, "test")
462+
self.dataset_config.test_fillers, None, workers, "test", False)
462463
self.dataset_config.test_ts = updated_ts_ids
463464
self.dataset_config.test_ts_row_ranges = updated_ts_row_ranges
464465
self.dataset_config.test_fillers = updated_fillers
@@ -539,7 +540,7 @@ def _get_dataloader(self, dataset: SplittedDataset, workers: int | Literal["conf
539540

540541
return self._get_time_based_dataloader(dataset, workers, take_all, batch_size)
541542

542-
def __initialize_transformers_and_details_for_set(self, ts_ids, ts_row_ranges, time_period, fillers, anomaly_handlers, workers, set_name):
543+
def __initialize_transformers_and_details_for_set(self, ts_ids, ts_row_ranges, time_period, fillers, anomaly_handlers, workers, set_name, can_fit_transformers):
543544
"""Initializes transformers and details for provided time series. """
544545
init_dataset = DisjointTimeBasedInitializerDataset(self.dataset_path,
545546
self.dataset_config._get_table_data_path(),
@@ -571,12 +572,8 @@ def __initialize_transformers_and_details_for_set(self, ts_ids, ts_row_ranges, t
571572
ts_ids_to_take.append(i)
572573

573574
# Fit transformers if required
574-
if self.dataset_config.transform_with is not None and data is not None and (not self.dataset_config.are_transformers_premade or self.dataset_config.partial_fit_initialized_transformers):
575-
576-
if self.dataset_config.are_transformers_premade and self.dataset_config.partial_fit_initialized_transformers:
577-
self.dataset_config.transformers.partial_fit(data)
578-
else:
579-
self.dataset_config.transformers.partial_fit(data)
575+
if can_fit_transformers and data is not None:
576+
self.dataset_config.transformers.partial_fit(data)
580577

581578
# Sets fitted anomaly handlers
582579
if self.dataset_config.handle_anomalies_with is not None and anomaly_handler is not None:

0 commit comments

Comments
 (0)