1+ from dataclasses import dataclass
2+ from typing import Dict , Optional
3+
14import numpy as np
25import pytest
6+ from sklearn .ensemble import RandomForestClassifier
37
48from fedot import Fedot
59from fedot .api .sampling_stage .executor import SamplingStageExecutor , SamplingStageOutput
6- from fedot .api .sampling_stage .providers import SamplingProvider , SamplingSubsetResult
10+ from fedot .api .sampling_stage .providers import SamplingProvider , SamplingSubsetResult , SamplingZooProvider
11+ from fedot .core .pipelines .pipeline_ensemble import PipelineEnsemble
12+ from fedot .core .pipelines .pipeline import Pipeline
713from fedot .core .repository .tasks import TsForecastingParams
814from test .data .datasets import get_dataset
915
@@ -25,7 +31,7 @@ def sample(self,
2531 target = np .asarray (target ).reshape (- 1 )
2632 for label in np .unique (target ):
2733 label_idx = np .where (target == label )[0 ]
28- k = max (1 , int (round (len (label_idx ) * ratio )))
34+ k = max (1 , int (round (len (label_idx ) * injectable_params [ ' ratio' ] )))
2935 picked = rng .choice (label_idx , size = min (k , len (label_idx )), replace = False )
3036 indices .extend (picked .tolist ())
3137
@@ -35,6 +41,201 @@ def sample(self,
3541 meta = {'provider' : 'stratified_stub' })
3642
3743
@dataclass(frozen=True)
class StrategySpec:
    """Immutable description of one sampling strategy exercised by the tests.

    Instances of this class parametrize
    ``test_sampling_stage_runs_all_strategies``; frozen so specs can be
    safely shared across parametrized runs.
    """

    # Strategy identifier passed to the sampling stage (e.g. 'random').
    name: str
    # Strategy family: 'chunking' (partitions the data) or 'subset'.
    kind: str
    # Problem type the strategy is run against ('classification',
    # 'regression' or 'ts_forecasting').
    task_type: str
    # Keyword arguments forwarded as 'strategy_params' in the sampling config.
    strategy_params: Dict[str, object]
    # When set, the parametrized test is skipped with this message.
    skip_reason: Optional[str] = None
52+
# One spec per sampling strategy supported by the sampling stage.
# Order is cosmetic; each entry becomes an independent parametrized test case.
SAMPLING_STRATEGY_SPECS = [
    # --- chunking strategies (partition the data, fitted as an ensemble) ---
    StrategySpec(
        name='random',
        kind='chunking',
        task_type='classification',
        strategy_params={'n_partitions': 10},
    ),
    StrategySpec(
        name='stratified',
        kind='chunking',
        task_type='classification',
        strategy_params={'n_partitions': 10},
    ),
    StrategySpec(
        name='advanced_stratified',
        kind='chunking',
        task_type='classification',
        strategy_params={'n_partitions': 10},
    ),
    StrategySpec(
        name='regression_stratified',
        kind='chunking',
        task_type='regression',
        strategy_params={
            # Binning configuration used to stratify a continuous target.
            'n_bins': 5,
            'encode': 'ordinal',
            'strategy': 'quantile',
            'n_partitions': 10,
            'use_advanced': True,
        },
    ),
    StrategySpec(
        name='temporal',
        kind='chunking',
        task_type='ts_forecasting',
        strategy_params={},
        skip_reason='Temporal strategies are not supported by sampling stage yet.',
    ),
    StrategySpec(
        name='difficulty',
        kind='chunking',
        task_type='classification',
        strategy_params={
            'difficulty_threshold': 0.5,
            'difficulty_metric': 'f1',
            'n_partitions': 10,
            'problem': 'classification',
            # Small forest with a fixed seed keeps the test fast and deterministic.
            'model': RandomForestClassifier(n_estimators=10, random_state=42),
            'chunks_percent': 50,
        },
    ),
    StrategySpec(
        name='uncertainty',
        kind='chunking',
        task_type='classification',
        strategy_params={
            'uncertainty_threshold': 0.5,
            'n_partitions': 10,
            'problem': 'classification',
            # Small forest with a fixed seed keeps the test fast and deterministic.
            'model': RandomForestClassifier(n_estimators=10, random_state=42),
            'chunks_percent': 50,
        },
    ),
    StrategySpec(
        name='balance',
        kind='chunking',
        task_type='classification',
        strategy_params={
            'n_partitions': 10,
            'balance_method': 'random',
            'balancer_kwargs': {},
        },
    ),
    StrategySpec(
        name='feature_clustering',
        kind='chunking',
        task_type='classification',
        strategy_params={
            'n_partitions': 10,
            'method': 'kmeans',
            'feature_engineering': False,
        },
    ),
    StrategySpec(
        name='tsne_clustering',
        kind='chunking',
        task_type='classification',
        strategy_params={
            'n_components': 2,
            # Low perplexity keeps t-SNE cheap on the small test dataset.
            'perplexity': 5,
        },
    ),
    StrategySpec(
        name='delaunay',
        kind='chunking',
        task_type='classification',
        strategy_params={
            'n_partitions': 10,
            'n_clusters': 2,
            'emptiness_threshold': 0.1,
            'dim_reduction_method': 'pca',
            'dim_reduction_target': 2,
        },
    ),
    StrategySpec(
        name='hdbscan',
        kind='chunking',
        task_type='classification',
        strategy_params={
            'min_cluster_size': 5,
            'one_cluster': True,
            'prob_threshold': 0.5,
            'all_points': True,
        },
    ),
    StrategySpec(
        name='voronoi',
        kind='chunking',
        task_type='classification',
        strategy_params={
            'n_partitions': 10,
            'emptiness_threshold': 0.1,
        },
    ),
    # --- subset strategies (select a reduced subset, fitted as one pipeline) ---
    StrategySpec(
        name='spectral_leverage',
        kind='subset',
        task_type='classification',
        strategy_params={},
    ),
    StrategySpec(
        name='tensor_energy',
        kind='subset',
        task_type='classification',
        strategy_params={},
    ),
    StrategySpec(
        name='kernel',
        kind='subset',
        task_type='classification',
        strategy_params={},
    ),
]
196+
197+
def _sampling_zoo_available() -> bool:
    """Return ``True`` when the optional Sampling Zoo backend can be loaded.

    Probes availability by instantiating the provider and loading its
    factory; any import-related failure is treated as "not installed".
    """
    try:
        SamplingZooProvider().load_factory()
    except ImportError:
        # ImportError is the parent of ModuleNotFoundError; catching it keeps
        # this probe consistent with the (ModuleNotFoundError, ImportError)
        # handling in test_sampling_stage_runs_all_strategies and also covers
        # partially-installed backends that raise plain ImportError.
        return False
    return True
204+
205+
@pytest.fixture(scope='session')
def sampling_zoo_available():
    """Skip the requesting test when the Sampling Zoo dependency is absent."""
    available = _sampling_zoo_available()
    if available:
        return
    pytest.skip('Sampling Zoo dependency is not available.')
210+
211+
@pytest.fixture(scope='session')
def classification_train_data():
    """Synthetic classification dataset built once and shared for the session."""
    data, *_ = get_dataset('classification', n_samples=10000, n_features=6, iris_dataset=False)
    return data
216+
217+
@pytest.fixture(scope='session')
def regression_train_data():
    """Synthetic regression dataset built once and shared for the session."""
    data, *_ = get_dataset('regression', n_samples=10000, n_features=6, iris_dataset=False)
    return data
222+
223+
def _build_sampling_config(spec: StrategySpec) -> Dict[str, object]:
    """Translate a :class:`StrategySpec` into the ``sampling_config`` dict Fedot expects."""
    base: Dict[str, object] = {
        'strategy_kind': spec.kind,
        'provider': 'sampling_zoo',
        'strategy': spec.name,
        'strategy_params': spec.strategy_params,
    }
    if spec.kind != 'subset':
        return base
    # Subset strategies additionally need candidate ratios and a permissive
    # metric-degradation threshold so the reduced subset is always accepted.
    base['candidate_ratios'] = [0.5]
    base['delta_metric_threshold'] = 1.0
    return base
237+
238+
38239def test_fit_with_sampling_config_none_preserves_default_behavior ():
39240 train_data , _ , _ = get_dataset ('classification' , n_samples = 120 , n_features = 6 , iris_dataset = False )
40241
@@ -200,6 +401,41 @@ def test_fail_fast_for_multimodal_input_with_sampling_stage():
200401 model .fit (features = data , target = target )
201402
202403
@pytest.mark.integration
@pytest.mark.slow
@pytest.mark.parametrize('spec', SAMPLING_STRATEGY_SPECS, ids=lambda spec: f'{spec.kind}:{spec.name}')
def test_sampling_stage_runs_all_strategies(spec: StrategySpec,
                                            sampling_zoo_available,
                                            classification_train_data,
                                            regression_train_data):
    """Smoke-test every sampling strategy end to end through ``Fedot.fit``."""
    if spec.skip_reason:
        pytest.skip(spec.skip_reason)

    if spec.task_type == 'classification':
        train_data = classification_train_data
    else:
        train_data = regression_train_data

    # Tiny timeout and search space: the point is that the stage runs, not quality.
    model = Fedot(problem=spec.task_type,
                  timeout=0.2,
                  preset='fast_train',
                  max_depth=1,
                  max_arity=2,
                  sampling_config=_build_sampling_config(spec))

    try:
        pipeline = model.fit(features=train_data)
    except (ModuleNotFoundError, ImportError) as exc:
        # Optional strategy backends may be missing in this environment.
        pytest.skip(str(exc))

    assert pipeline is not None
    metadata = model.sampling_stage_metadata
    assert metadata is not None
    assert metadata['status'] == 'applied'
    if spec.kind == 'chunking':
        # Chunking strategies fit one pipeline per partition and ensemble them.
        assert isinstance(model.current_pipeline, PipelineEnsemble)
        assert isinstance(model.train_data, list)
    else:
        # Subset strategies keep a single pipeline trained on the reduced data.
        assert isinstance(model.current_pipeline, Pipeline)
438+
203439def test_timeout_restored_after_sampling_stage_real_path (monkeypatch ):
204440 train_data , _ , _ = get_dataset ('classification' , n_samples = 90 , n_features = 6 , iris_dataset = False )
205441
@@ -235,4 +471,3 @@ def fake_execute(self, train_data_input):
235471 assert model .sampling_stage_metadata is not None
236472 assert model .sampling_stage_metadata ['status' ] == 'applied'
237473 assert model .params .timeout == pytest .approx (0.2 )
238-
0 commit comments