Skip to content

Commit 66be6fa

Browse files
test: add all sampling zoo strategies tests
1 parent 43b5f44 commit 66be6fa

File tree

4 files changed

+243
-7
lines changed

4 files changed

+243
-7
lines changed

fedot/api/sampling_stage/executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def _execute_chunking(self,
134134
timeout_after_stage = self._compute_updated_timeout(elapsed_seconds)
135135

136136
metadata = {
137-
'status': 'chunking',
137+
'status': 'applied',
138138
'provider': self.config.provider,
139139
'strategy': self.config.strategy,
140140
'rows_before': int(len(train_data.idx)),

fedot/api/sampling_stage/providers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class SamplingZooProvider(SamplingProvider):
4545
)
4646

4747
def __init__(self):
48-
self._factory_cls = self._load_factory()
48+
self._factory_cls = self.load_factory()
4949

5050
def sample(self,
5151
features: np.ndarray,
@@ -457,7 +457,7 @@ def _inject_required_kwargs(factory: Any,
457457

458458
return updated_kwargs
459459

460-
def _load_factory(self):
460+
def load_factory(self):
461461
for module_name in self._SAMPLING_MODULE_CANDIDATES:
462462
try:
463463
module = import_module(module_name)

test/integration/api/test_sampling_stage_integration.py

Lines changed: 238 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1+
from dataclasses import dataclass
2+
from typing import Dict, Optional
3+
14
import numpy as np
25
import pytest
6+
from sklearn.ensemble import RandomForestClassifier
37

48
from fedot import Fedot
59
from fedot.api.sampling_stage.executor import SamplingStageExecutor, SamplingStageOutput
6-
from fedot.api.sampling_stage.providers import SamplingProvider, SamplingSubsetResult
10+
from fedot.api.sampling_stage.providers import SamplingProvider, SamplingSubsetResult, SamplingZooProvider
11+
from fedot.core.pipelines.pipeline_ensemble import PipelineEnsemble
12+
from fedot.core.pipelines.pipeline import Pipeline
713
from fedot.core.repository.tasks import TsForecastingParams
814
from test.data.datasets import get_dataset
915

@@ -25,7 +31,7 @@ def sample(self,
2531
target = np.asarray(target).reshape(-1)
2632
for label in np.unique(target):
2733
label_idx = np.where(target == label)[0]
28-
k = max(1, int(round(len(label_idx) * ratio)))
34+
k = max(1, int(round(len(label_idx) * injectable_params['ratio'])))
2935
picked = rng.choice(label_idx, size=min(k, len(label_idx)), replace=False)
3036
indices.extend(picked.tolist())
3137

@@ -35,6 +41,201 @@ def sample(self,
3541
meta={'provider': 'stratified_stub'})
3642

3743

44+
@dataclass(frozen=True)
45+
class StrategySpec:
46+
name: str
47+
kind: str
48+
task_type: str
49+
strategy_params: Dict[str, object]
50+
skip_reason: Optional[str] = None
51+
52+
53+
SAMPLING_STRATEGY_SPECS = [
54+
StrategySpec(
55+
name='random',
56+
kind='chunking',
57+
task_type='classification',
58+
strategy_params={'n_partitions': 10},
59+
),
60+
StrategySpec(
61+
name='stratified',
62+
kind='chunking',
63+
task_type='classification',
64+
strategy_params={'n_partitions': 10},
65+
),
66+
StrategySpec(
67+
name='advanced_stratified',
68+
kind='chunking',
69+
task_type='classification',
70+
strategy_params={'n_partitions': 10},
71+
),
72+
StrategySpec(
73+
name='regression_stratified',
74+
kind='chunking',
75+
task_type='regression',
76+
strategy_params={
77+
'n_bins': 5,
78+
'encode': 'ordinal',
79+
'strategy': 'quantile',
80+
'n_partitions': 10,
81+
'use_advanced': True,
82+
},
83+
),
84+
StrategySpec(
85+
name='temporal',
86+
kind='chunking',
87+
task_type='ts_forecasting',
88+
strategy_params={},
89+
skip_reason='Temporal strategies are not supported by sampling stage yet.',
90+
),
91+
StrategySpec(
92+
name='difficulty',
93+
kind='chunking',
94+
task_type='classification',
95+
strategy_params={
96+
'difficulty_threshold': 0.5,
97+
'difficulty_metric': 'f1',
98+
'n_partitions': 10,
99+
'problem': 'classification',
100+
'model': RandomForestClassifier(n_estimators=10, random_state=42),
101+
'chunks_percent': 50,
102+
},
103+
),
104+
StrategySpec(
105+
name='uncertainty',
106+
kind='chunking',
107+
task_type='classification',
108+
strategy_params={
109+
'uncertainty_threshold': 0.5,
110+
'n_partitions': 10,
111+
'problem': 'classification',
112+
'model': RandomForestClassifier(n_estimators=10, random_state=42),
113+
'chunks_percent': 50,
114+
},
115+
),
116+
StrategySpec(
117+
name='balance',
118+
kind='chunking',
119+
task_type='classification',
120+
strategy_params={
121+
'n_partitions': 10,
122+
'balance_method': 'random',
123+
'balancer_kwargs': {},
124+
},
125+
),
126+
StrategySpec(
127+
name='feature_clustering',
128+
kind='chunking',
129+
task_type='classification',
130+
strategy_params={
131+
'n_partitions': 10,
132+
'method': 'kmeans',
133+
'feature_engineering': False,
134+
},
135+
),
136+
StrategySpec(
137+
name='tsne_clustering',
138+
kind='chunking',
139+
task_type='classification',
140+
strategy_params={
141+
'n_components': 2,
142+
'perplexity': 5,
143+
},
144+
),
145+
StrategySpec(
146+
name='delaunay',
147+
kind='chunking',
148+
task_type='classification',
149+
strategy_params={
150+
'n_partitions': 10,
151+
'n_clusters': 2,
152+
'emptiness_threshold': 0.1,
153+
'dim_reduction_method': 'pca',
154+
'dim_reduction_target': 2,
155+
},
156+
),
157+
StrategySpec(
158+
name='hdbscan',
159+
kind='chunking',
160+
task_type='classification',
161+
strategy_params={
162+
'min_cluster_size': 5,
163+
'one_cluster': True,
164+
'prob_threshold': 0.5,
165+
'all_points': True,
166+
},
167+
),
168+
StrategySpec(
169+
name='voronoi',
170+
kind='chunking',
171+
task_type='classification',
172+
strategy_params={
173+
'n_partitions': 10,
174+
'emptiness_threshold': 0.1,
175+
},
176+
),
177+
StrategySpec(
178+
name='spectral_leverage',
179+
kind='subset',
180+
task_type='classification',
181+
strategy_params={},
182+
),
183+
StrategySpec(
184+
name='tensor_energy',
185+
kind='subset',
186+
task_type='classification',
187+
strategy_params={},
188+
),
189+
StrategySpec(
190+
name='kernel',
191+
kind='subset',
192+
task_type='classification',
193+
strategy_params={},
194+
),
195+
]
196+
197+
198+
def _sampling_zoo_available() -> bool:
199+
try:
200+
SamplingZooProvider().load_factory()
201+
return True
202+
except ModuleNotFoundError:
203+
return False
204+
205+
206+
@pytest.fixture(scope='session')
207+
def sampling_zoo_available():
208+
if not _sampling_zoo_available():
209+
pytest.skip('Sampling Zoo dependency is not available.')
210+
211+
212+
@pytest.fixture(scope='session')
213+
def classification_train_data():
214+
train_data, _, _ = get_dataset('classification', n_samples=10000, n_features=6, iris_dataset=False)
215+
return train_data
216+
217+
218+
@pytest.fixture(scope='session')
219+
def regression_train_data():
220+
train_data, _, _ = get_dataset('regression', n_samples=10000, n_features=6, iris_dataset=False)
221+
return train_data
222+
223+
224+
def _build_sampling_config(spec: StrategySpec) -> Dict[str, object]:
225+
config: Dict[str, object] = {
226+
'strategy_kind': spec.kind,
227+
'provider': 'sampling_zoo',
228+
'strategy': spec.name,
229+
'strategy_params': spec.strategy_params,
230+
}
231+
if spec.kind == 'subset':
232+
config.update({
233+
'candidate_ratios': [0.5],
234+
'delta_metric_threshold': 1.0,
235+
})
236+
return config
237+
238+
38239
def test_fit_with_sampling_config_none_preserves_default_behavior():
39240
train_data, _, _ = get_dataset('classification', n_samples=120, n_features=6, iris_dataset=False)
40241

@@ -200,6 +401,41 @@ def test_fail_fast_for_multimodal_input_with_sampling_stage():
200401
model.fit(features=data, target=target)
201402

202403

404+
@pytest.mark.integration
405+
@pytest.mark.slow
406+
@pytest.mark.parametrize('spec', SAMPLING_STRATEGY_SPECS, ids=lambda spec: f'{spec.kind}:{spec.name}')
407+
def test_sampling_stage_runs_all_strategies(spec: StrategySpec,
408+
sampling_zoo_available,
409+
classification_train_data,
410+
regression_train_data):
411+
if spec.skip_reason:
412+
pytest.skip(spec.skip_reason)
413+
414+
train_data = classification_train_data if spec.task_type == 'classification' else regression_train_data
415+
sampling_config = _build_sampling_config(spec)
416+
417+
model = Fedot(problem=spec.task_type,
418+
timeout=0.2,
419+
preset='fast_train',
420+
max_depth=1,
421+
max_arity=2,
422+
sampling_config=sampling_config)
423+
424+
try:
425+
pipeline = model.fit(features=train_data)
426+
except (ModuleNotFoundError, ImportError) as exc:
427+
pytest.skip(str(exc))
428+
429+
assert pipeline is not None
430+
assert model.sampling_stage_metadata is not None
431+
assert model.sampling_stage_metadata['status'] == 'applied'
432+
if spec.kind == 'chunking':
433+
assert isinstance(model.current_pipeline, PipelineEnsemble)
434+
assert isinstance(model.train_data, list)
435+
else:
436+
assert isinstance(model.current_pipeline, Pipeline)
437+
438+
203439
def test_timeout_restored_after_sampling_stage_real_path(monkeypatch):
204440
train_data, _, _ = get_dataset('classification', n_samples=90, n_features=6, iris_dataset=False)
205441

@@ -235,4 +471,3 @@ def fake_execute(self, train_data_input):
235471
assert model.sampling_stage_metadata is not None
236472
assert model.sampling_stage_metadata['status'] == 'applied'
237473
assert model.params.timeout == pytest.approx(0.2)
238-

tests/api/api_utils/test_api_params_repository.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ def test_api_params_repository_preserves_valid_sampling_config():
1616
repository = ApiParamsRepository(TaskTypesEnum.classification)
1717

1818
result = repository.check_and_set_default_params({
19-
'sampling_config': {'strategy': 'random', 'candidate_ratios': [0.2, 0.5]},
19+
'sampling_config': {'strategy_kind': 'subset', 'strategy': 'random', 'candidate_ratios': [0.2, 0.5]},
2020
})
2121

22+
assert result['sampling_config']['strategy_kind'] == 'subset'
2223
assert result['sampling_config']['strategy'] == 'random'
2324
assert tuple(result['sampling_config']['candidate_ratios']) == (0.2, 0.5)
2425

0 commit comments

Comments
 (0)