Skip to content

Commit 43b5f44

Browse files
feat: divide chunking and subset sampling strategies
1 parent fdab6c0 commit 43b5f44

File tree

14 files changed

+674
-124
lines changed

14 files changed

+674
-124
lines changed

examples/benchmark/run_amlb.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class BenchmarkRunConfig:
6666

6767
def _default_sampling_config(seed: int) -> Dict[str, Any]:
6868
return {
69+
'strategy_kind': 'subset',
6970
'provider': 'sampling_zoo',
7071
'strategy': 'random',
7172
'strategy_params': {},

fedot/api/api_utils/api_composer.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
from fedot.core.composer.composer_builder import ComposerBuilder
1919
from fedot.core.composer.gp_composer.gp_composer import GPComposer
2020
from fedot.core.constants import DEFAULT_TUNING_ITERATIONS_NUMBER
21-
from fedot.core.data.data import InputData
21+
from fedot.core.data.data import InputData, InputDataList
22+
from fedot.core.pipelines.pipeline_ensemble import PipelineEnsemble
2223
from fedot.core.pipelines.pipeline import Pipeline
2324
from fedot.core.pipelines.tuning.tuner_builder import TunerBuilder
2425
from fedot.core.repository.metrics_repository import MetricIDType
@@ -112,6 +113,23 @@ def obtain_model(self, train_data: InputData) -> Tuple[Pipeline, Sequence[Pipeli
112113
self.log.message('Model generation finished')
113114
return best_pipeline, best_pipeline_candidates, gp_composer.history
114115

116+
def obtain_ensemble_model(self, train_data_list: InputDataList) -> \
117+
Tuple[PipelineEnsemble, Sequence[Sequence[Pipeline]], List[OptHistory]]:
118+
pipelines: List[Pipeline] = []
119+
best_models: List[Sequence[Pipeline]] = []
120+
histories: List[OptHistory] = []
121+
122+
for chunk_data in train_data_list:
123+
pipeline, best_pipeline_candidates, history = self.obtain_model(chunk_data)
124+
if pipeline is None:
125+
raise ValueError('No models were found for one of the chunks')
126+
pipelines.append(pipeline)
127+
best_models.append(best_pipeline_candidates)
128+
histories.append(history)
129+
130+
ensemble = PipelineEnsemble(pipelines)
131+
return ensemble, best_models, histories
132+
115133
def propose_and_fit_initial_assumption(self, train_data: InputData) -> Tuple[Sequence[Pipeline], Pipeline]:
116134
""" Method for obtaining and fitting initial assumption"""
117135
available_operations = self.params.get('available_operations')
@@ -210,4 +228,4 @@ def tune_final_pipeline(self, train_data: InputData,
210228
tuned_pipeline = tuner.tune(pipeline_gp_composed)
211229
self.log.message('Hyperparameters tuning finished')
212230
self.was_tuned = tuner.was_tuned
213-
return tuned_pipeline
231+
return tuned_pipeline

fedot/api/api_utils/api_data.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from fedot.core.data.data import InputData, OutputData, data_type_is_table
1717
from fedot.core.data.data_preprocessing import convert_into_column
1818
from fedot.core.data.multi_modal import MultiModalData
19+
from fedot.core.pipelines.pipeline_ensemble import PipelineEnsemble
1920
from fedot.core.pipelines.pipeline import Pipeline
2021
from fedot.core.pipelines.ts_wrappers import in_sample_ts_forecast, convert_forecast_to_output
2122
from fedot.core.repository.tasks import Task, TaskTypesEnum
@@ -85,7 +86,8 @@ def define_data(self,
8586
data = self.preprocessor.obligatory_prepare_for_fit(data)
8687
return data
8788

88-
def define_predictions(self, current_pipeline: Pipeline, test_data: Union[InputData, MultiModalData],
89+
def define_predictions(self, current_pipeline: Union[Pipeline, PipelineEnsemble],
90+
test_data: Union[InputData, MultiModalData],
8991
in_sample: bool = False, validation_blocks: int = None) -> OutputData:
9092
""" Prepare predictions """
9193
forecast_length = getattr(test_data.task.task_params, 'forecast_length', None)

fedot/api/main.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,12 @@
2525
from fedot.api.api_utils.predefined_model import PredefinedModel
2626
from fedot.api.sampling_stage.executor import SamplingStageExecutor
2727
from fedot.core.constants import DEFAULT_API_TIMEOUT_MINUTES, DEFAULT_TUNING_ITERATIONS_NUMBER
28-
from fedot.core.data.data import InputData, OutputData, PathType
28+
from fedot.core.data.data import InputData, InputDataList, OutputData, PathType
2929
from fedot.core.data.multi_modal import MultiModalData
3030
from fedot.core.data.visualisation import plot_biplot, plot_forecast, plot_roc_auc
3131
from fedot.core.optimisers.objective import PipelineObjectiveEvaluate
3232
from fedot.core.optimisers.objective.metrics_objective import MetricsObjective
33+
from fedot.core.pipelines.pipeline_ensemble import PipelineEnsemble
3334
from fedot.core.pipelines.pipeline import Pipeline
3435
from fedot.core.pipelines.ts_wrappers import convert_forecast_to_output, out_of_sample_ts_forecast
3536
from fedot.core.pipelines.tuning.tuner_builder import TunerBuilder
@@ -118,13 +119,13 @@ def __init__(self,
118119
self.target: Optional[TargetType] = None
119120
self.prediction: Optional[OutputData] = None
120121
self._is_in_sample_prediction = True
121-
self.train_data: Optional[InputData] = None
122+
self.train_data: Optional[Union[InputData, InputDataList]] = None
122123
self.test_data: Optional[InputData] = None
123124

124125
# Outputs
125-
self.current_pipeline: Optional[Pipeline] = None
126-
self.best_models: Sequence[Pipeline] = ()
127-
self.history: Optional[OptHistory] = None
126+
self.current_pipeline: Optional[Union[Pipeline, PipelineEnsemble]] = None
127+
self.best_models: Sequence[Union[Pipeline, Sequence[Pipeline]]] = ()
128+
self.history: Optional[Union[OptHistory, Sequence[OptHistory]]] = None
128129
self.sampling_stage_metadata: Optional[dict] = None
129130

130131
fedot_composer_timer.reset_timer()
@@ -202,7 +203,12 @@ def fit(self,
202203
api_preprocessor=self.data_processor.preprocessor,
203204
).fit()
204205
else:
205-
self.current_pipeline, self.best_models, self.history = self.api_composer.obtain_model(self.train_data)
206+
if isinstance(self.train_data, InputDataList):
207+
self.current_pipeline, self.best_models, self.history = \
208+
self.api_composer.obtain_ensemble_model(self.train_data)
209+
else:
210+
self.current_pipeline, self.best_models, self.history = \
211+
self.api_composer.obtain_model(self.train_data)
206212

207213
if self.current_pipeline is None:
208214
raise ValueError('No models were found')
@@ -219,13 +225,20 @@ def fit(self,
219225
self.log.message('Already fitted initial pipeline is used')
220226

221227
# Merge API & pipelines encoders if it is required
222-
self.current_pipeline.preprocessor = BasePreprocessor.merge_preprocessors(
228+
merged_preprocessor = BasePreprocessor.merge_preprocessors(
223229
api_preprocessor=self.data_processor.preprocessor,
224230
pipeline_preprocessor=self.current_pipeline.preprocessor,
225231
use_auto_preprocessing=self.params.get('use_auto_preprocessing')
226232
)
233+
self.current_pipeline.preprocessor = merged_preprocessor
234+
if isinstance(self.current_pipeline, PipelineEnsemble):
235+
for pipeline in self.current_pipeline.pipelines:
236+
pipeline.preprocessor = merged_preprocessor
227237

228-
self.log.message(f'Final pipeline: {graph_structure(self.current_pipeline)}')
238+
if isinstance(self.current_pipeline, Pipeline):
239+
self.log.message(f'Final pipeline: {graph_structure(self.current_pipeline)}')
240+
else:
241+
self.log.message(f'Final pipeline ensemble: {len(self.current_pipeline.pipelines)} pipelines')
229242

230243
return self.current_pipeline
231244
finally:
@@ -258,6 +271,8 @@ def tune(self,
258271
"""
259272
if self.current_pipeline is None:
260273
raise ValueError(NOT_FITTED_ERR_MSG)
274+
if isinstance(self.current_pipeline, PipelineEnsemble):
275+
raise ValueError('Tuning for pipeline ensembles is not supported yet.')
261276

262277
with fedot_composer_timer.launch_tuning('post'):
263278
tune_plan = build_tune_execution_plan(
@@ -618,15 +633,18 @@ def _run_sampling_stage_if_necessary(self):
618633
)
619634

620635
def _train_pipeline_on_full_dataset(self, recommendations: Optional[dict],
621-
full_train_not_preprocessed: Union[InputData, MultiModalData]):
636+
full_train_not_preprocessed: Union[InputData, InputDataList, MultiModalData]):
622637
"""Applies training procedure for obtained pipeline if dataset was clipped
623638
"""
624639

625640
if recommendations is not None:
626641
# if data was cut we need to refit pipeline on full data
627-
self.data_processor.accept_and_apply_recommendations(full_train_not_preprocessed,
628-
{k: v for k, v in recommendations.items()
629-
if k != 'cut'})
642+
cleaned_recommendations = {k: v for k, v in recommendations.items() if k != 'cut'}
643+
if isinstance(full_train_not_preprocessed, list):
644+
for chunk_data in full_train_not_preprocessed:
645+
self.data_processor.accept_and_apply_recommendations(chunk_data, cleaned_recommendations)
646+
else:
647+
self.data_processor.accept_and_apply_recommendations(full_train_not_preprocessed, cleaned_recommendations)
630648
self.current_pipeline.fit(
631649
full_train_not_preprocessed,
632650
n_jobs=self.params.n_jobs

fedot/api/sampling_stage/__init__.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,28 @@
1-
from fedot.api.sampling_stage.config import SamplingConfig, validate_sampling_config
1+
from fedot.api.sampling_stage.config import (
2+
SamplingChunkingConfig,
3+
SamplingConfig,
4+
SamplingConfigBase,
5+
SamplingSubsetConfig,
6+
validate_sampling_config,
7+
)
28
from fedot.api.sampling_stage.executor import SamplingStageExecutor, SamplingStageOutput
3-
from fedot.api.sampling_stage.providers import SamplingProvider, SamplingProviderResult, SamplingZooProvider
9+
from fedot.api.sampling_stage.providers import (
10+
SamplingProvider,
11+
SamplingProviderResult,
12+
SamplingSubsetResult,
13+
SamplingChunkingResult,
14+
SamplingZooProvider,
15+
)
416

517
__all__ = [
618
'SamplingConfig',
19+
'SamplingConfigBase',
20+
'SamplingChunkingConfig',
21+
'SamplingSubsetConfig',
722
'SamplingProvider',
823
'SamplingProviderResult',
24+
'SamplingSubsetResult',
25+
'SamplingChunkingResult',
926
'SamplingStageExecutor',
1027
'SamplingStageOutput',
1128
'SamplingZooProvider',

0 commit comments

Comments
 (0)