@@ -434,8 +434,9 @@ def _initialize_transformers_and_details(self, workers: int) -> None:
434434 all_ts_ids_to_take = np .array ([])
435435
436436 if self .dataset_config .has_train ():
437+ can_fit_transformers = self .dataset_config .transform_with is not None and (not self .dataset_config .are_transformers_premade or self .dataset_config .partial_fit_initialized_transformers )
437438 updated_ts_row_ranges , updated_ts_ids , updated_fillers , updated_anomaly_handlers = self .__initialize_transformers_and_details_for_set (self .dataset_config .train_ts , self .dataset_config .train_ts_row_ranges , self .dataset_config .train_time_period ,
438- self .dataset_config .train_fillers , self .dataset_config .anomaly_handlers , workers , "train" )
439+ self .dataset_config .train_fillers , self .dataset_config .anomaly_handlers , workers , "train" , can_fit_transformers )
439440 self .dataset_config .train_ts = updated_ts_ids
440441 self .dataset_config .train_ts_row_ranges = updated_ts_row_ranges
441442 self .dataset_config .train_fillers = updated_fillers
@@ -447,7 +448,7 @@ def _initialize_transformers_and_details(self, workers: int) -> None:
447448
448449 if self .dataset_config .has_val ():
449450 updated_ts_row_ranges , updated_ts_ids , updated_fillers , _ = self .__initialize_transformers_and_details_for_set (self .dataset_config .val_ts , self .dataset_config .val_ts_row_ranges , self .dataset_config .val_time_period ,
450- self .dataset_config .val_fillers , None , workers , "val" )
451+ self .dataset_config .val_fillers , None , workers , "val" , False )
451452 self .dataset_config .val_ts = updated_ts_ids
452453 self .dataset_config .val_ts_row_ranges = updated_ts_row_ranges
453454 self .dataset_config .val_fillers = updated_fillers
@@ -458,7 +459,7 @@ def _initialize_transformers_and_details(self, workers: int) -> None:
458459
459460 if self .dataset_config .has_test ():
460461 updated_ts_row_ranges , updated_ts_ids , updated_fillers , _ = self .__initialize_transformers_and_details_for_set (self .dataset_config .test_ts , self .dataset_config .test_ts_row_ranges , self .dataset_config .test_time_period ,
461- self .dataset_config .test_fillers , None , workers , "test" )
462+ self .dataset_config .test_fillers , None , workers , "test" , False )
462463 self .dataset_config .test_ts = updated_ts_ids
463464 self .dataset_config .test_ts_row_ranges = updated_ts_row_ranges
464465 self .dataset_config .test_fillers = updated_fillers
@@ -539,7 +540,7 @@ def _get_dataloader(self, dataset: SplittedDataset, workers: int | Literal["conf
539540
540541 return self ._get_time_based_dataloader (dataset , workers , take_all , batch_size )
541542
542- def __initialize_transformers_and_details_for_set (self , ts_ids , ts_row_ranges , time_period , fillers , anomaly_handlers , workers , set_name ):
543+ def __initialize_transformers_and_details_for_set (self , ts_ids , ts_row_ranges , time_period , fillers , anomaly_handlers , workers , set_name , can_fit_transformers ):
543544 """Initializes transformers and details for provided time series. """
544545 init_dataset = DisjointTimeBasedInitializerDataset (self .dataset_path ,
545546 self .dataset_config ._get_table_data_path (),
@@ -571,12 +572,8 @@ def __initialize_transformers_and_details_for_set(self, ts_ids, ts_row_ranges, t
571572 ts_ids_to_take .append (i )
572573
573574 # Fit transformers if required
574- if self .dataset_config .transform_with is not None and data is not None and (not self .dataset_config .are_transformers_premade or self .dataset_config .partial_fit_initialized_transformers ):
575-
576- if self .dataset_config .are_transformers_premade and self .dataset_config .partial_fit_initialized_transformers :
577- self .dataset_config .transformers .partial_fit (data )
578- else :
579- self .dataset_config .transformers .partial_fit (data )
575+ if can_fit_transformers and data is not None :
576+ self .dataset_config .transformers .partial_fit (data )
580577
581578 # Sets fitted anomaly handlers
582579 if self .dataset_config .handle_anomalies_with is not None and anomaly_handler is not None :
0 commit comments