2525from fedot .api .api_utils .predefined_model import PredefinedModel
2626from fedot .api .sampling_stage .executor import SamplingStageExecutor
2727from fedot .core .constants import DEFAULT_API_TIMEOUT_MINUTES , DEFAULT_TUNING_ITERATIONS_NUMBER
28- from fedot .core .data .data import InputData , OutputData , PathType
28+ from fedot .core .data .data import InputData , InputDataList , OutputData , PathType
2929from fedot .core .data .multi_modal import MultiModalData
3030from fedot .core .data .visualisation import plot_biplot , plot_forecast , plot_roc_auc
3131from fedot .core .optimisers .objective import PipelineObjectiveEvaluate
3232from fedot .core .optimisers .objective .metrics_objective import MetricsObjective
33+ from fedot .core .pipelines .pipeline_ensemble import PipelineEnsemble
3334from fedot .core .pipelines .pipeline import Pipeline
3435from fedot .core .pipelines .ts_wrappers import convert_forecast_to_output , out_of_sample_ts_forecast
3536from fedot .core .pipelines .tuning .tuner_builder import TunerBuilder
@@ -118,13 +119,13 @@ def __init__(self,
118119 self .target : Optional [TargetType ] = None
119120 self .prediction : Optional [OutputData ] = None
120121 self ._is_in_sample_prediction = True
121- self .train_data : Optional [InputData ] = None
122+ self .train_data : Optional [Union [ InputData , InputDataList ] ] = None
122123 self .test_data : Optional [InputData ] = None
123124
124125 # Outputs
125- self .current_pipeline : Optional [Pipeline ] = None
126- self .best_models : Sequence [Pipeline ] = ()
127- self .history : Optional [OptHistory ] = None
126+ self .current_pipeline : Optional [Union [ Pipeline , PipelineEnsemble ] ] = None
127+ self .best_models : Sequence [Union [ Pipeline , Sequence [ Pipeline ]] ] = ()
128+ self .history : Optional [Union [ OptHistory , Sequence [ OptHistory ]] ] = None
128129 self .sampling_stage_metadata : Optional [dict ] = None
129130
130131 fedot_composer_timer .reset_timer ()
@@ -202,7 +203,12 @@ def fit(self,
202203 api_preprocessor = self .data_processor .preprocessor ,
203204 ).fit ()
204205 else :
205- self .current_pipeline , self .best_models , self .history = self .api_composer .obtain_model (self .train_data )
206+ if isinstance (self .train_data , InputDataList ):
207+ self .current_pipeline , self .best_models , self .history = \
208+ self .api_composer .obtain_ensemble_model (self .train_data )
209+ else :
210+ self .current_pipeline , self .best_models , self .history = \
211+ self .api_composer .obtain_model (self .train_data )
206212
207213 if self .current_pipeline is None :
208214 raise ValueError ('No models were found' )
@@ -219,13 +225,20 @@ def fit(self,
219225 self .log .message ('Already fitted initial pipeline is used' )
220226
221227 # Merge API & pipelines encoders if it is required
222- self . current_pipeline . preprocessor = BasePreprocessor .merge_preprocessors (
228+ merged_preprocessor = BasePreprocessor .merge_preprocessors (
223229 api_preprocessor = self .data_processor .preprocessor ,
224230 pipeline_preprocessor = self .current_pipeline .preprocessor ,
225231 use_auto_preprocessing = self .params .get ('use_auto_preprocessing' )
226232 )
233+ self .current_pipeline .preprocessor = merged_preprocessor
234+ if isinstance (self .current_pipeline , PipelineEnsemble ):
235+ for pipeline in self .current_pipeline .pipelines :
236+ pipeline .preprocessor = merged_preprocessor
227237
228- self .log .message (f'Final pipeline: { graph_structure (self .current_pipeline )} ' )
238+ if isinstance (self .current_pipeline , Pipeline ):
239+ self .log .message (f'Final pipeline: { graph_structure (self .current_pipeline )} ' )
240+ else :
241+ self .log .message (f'Final pipeline ensemble: { len (self .current_pipeline .pipelines )} pipelines' )
229242
230243 return self .current_pipeline
231244 finally :
@@ -258,6 +271,8 @@ def tune(self,
258271 """
259272 if self .current_pipeline is None :
260273 raise ValueError (NOT_FITTED_ERR_MSG )
274+ if isinstance (self .current_pipeline , PipelineEnsemble ):
275+ raise ValueError ('Tuning for pipeline ensembles is not supported yet.' )
261276
262277 with fedot_composer_timer .launch_tuning ('post' ):
263278 tune_plan = build_tune_execution_plan (
@@ -618,15 +633,18 @@ def _run_sampling_stage_if_necessary(self):
618633 )
619634
620635 def _train_pipeline_on_full_dataset (self , recommendations : Optional [dict ],
621- full_train_not_preprocessed : Union [InputData , MultiModalData ]):
636+ full_train_not_preprocessed : Union [InputData , InputDataList , MultiModalData ]):
622637 """Applies training procedure for obtained pipeline if dataset was clipped
623638 """
624639
625640 if recommendations is not None :
626641 # if data was cut we need to refit pipeline on full data
627- self .data_processor .accept_and_apply_recommendations (full_train_not_preprocessed ,
628- {k : v for k , v in recommendations .items ()
629- if k != 'cut' })
642+ cleaned_recommendations = {k : v for k , v in recommendations .items () if k != 'cut' }
643+ if isinstance (full_train_not_preprocessed , list ):
644+ for chunk_data in full_train_not_preprocessed :
645+ self .data_processor .accept_and_apply_recommendations (chunk_data , cleaned_recommendations )
646+ else :
647+ self .data_processor .accept_and_apply_recommendations (full_train_not_preprocessed , cleaned_recommendations )
630648 self .current_pipeline .fit (
631649 full_train_not_preprocessed ,
632650 n_jobs = self .params .n_jobs
0 commit comments