Skip to content

Commit d0f18ba

Browse files
committed
extract cache and tuner setup rules from api composer
1 parent 4e2a1fe commit d0f18ba

File tree

4 files changed

+132
-22
lines changed

4 files changed

+132
-22
lines changed

fedot/api/api_utils/api_composer.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from golem.core.optimisers.opt_history_objects.opt_history import OptHistory
88
from golem.core.tuning.simultaneous import SimultaneousTuner
99

10+
from fedot.api.api_utils.api_composer_rules import build_cache_init_plan, build_tuner_plan
1011
from fedot.api.api_utils.api_run_planner import build_composer_execution_plan
1112
from fedot.api.api_utils.assumptions.assumptions_handler import AssumptionsHandler
1213
from fedot.api.api_utils.params import ApiParams
@@ -41,23 +42,23 @@ def __init__(self, api_params: ApiParams, metrics: Union[MetricIDType, Sequence[
4142
self.init_cache()
4243

4344
def init_cache(self):
44-
use_operations_cache = self.params.get('use_operations_cache')
45-
use_preprocessing_cache = self.params.get('use_preprocessing_cache')
46-
use_predictions_cache = self.params.get('use_predictions_cache')
47-
use_input_preprocessing = self.params.get('use_input_preprocessing')
48-
cache_dir = self.params.get('cache_dir')
49-
use_stats = self.params.get('use_stats')
50-
if use_operations_cache:
51-
self.operations_cache = OperationsCache(cache_dir=cache_dir, use_stats=use_stats)
52-
# in case of previously generated singleton cache
45+
cache_plan = build_cache_init_plan(
46+
use_operations_cache=self.params.get('use_operations_cache'),
47+
use_preprocessing_cache=self.params.get('use_preprocessing_cache'),
48+
use_predictions_cache=self.params.get('use_predictions_cache'),
49+
use_input_preprocessing=self.params.get('use_input_preprocessing'),
50+
cache_dir=self.params.get('cache_dir'),
51+
use_stats=self.params.get('use_stats'),
52+
)
53+
54+
if cache_plan.use_operations_cache:
55+
self.operations_cache = OperationsCache(cache_dir=cache_plan.cache_dir, use_stats=cache_plan.use_stats)
5356
self.operations_cache.reset()
54-
if use_input_preprocessing and use_preprocessing_cache:
55-
self.preprocessing_cache = PreprocessingCache(cache_dir=cache_dir, use_stats=use_stats)
56-
# in case of previously generated singleton cache
57+
if cache_plan.use_preprocessing_cache:
58+
self.preprocessing_cache = PreprocessingCache(cache_dir=cache_plan.cache_dir, use_stats=cache_plan.use_stats)
5759
self.preprocessing_cache.reset()
58-
if use_predictions_cache:
59-
self.predictions_cache = PredictionsCache(cache_dir=cache_dir, use_stats=use_stats)
60-
# in case of previously generated singleton cache
60+
if cache_plan.use_predictions_cache:
61+
self.predictions_cache = PredictionsCache(cache_dir=cache_plan.cache_dir, use_stats=cache_plan.use_stats)
6162
self.predictions_cache.reset()
6263

6364
def obtain_model(self, train_data: InputData) -> Tuple[Pipeline, Sequence[Pipeline], OptHistory]:
@@ -106,7 +107,6 @@ def obtain_model(self, train_data: InputData) -> Tuple[Pipeline, Sequence[Pipeli
106107
if gp_composer.history:
107108
adapter = self.params.graph_generation_params.adapter
108109
gp_composer.history.tuning_result = adapter.adapt(best_pipeline)
109-
# enforce memory cleaning
110110
gc.collect()
111111

112112
self.log.message('Model generation finished')
@@ -166,15 +166,13 @@ def compose_pipeline(self, train_data: InputData, initial_assumption: Sequence[P
166166
)
167167

168168
if execution_plan.should_compose:
169-
# Launch pipeline structure composition
170169
with self.timer.launch_composing():
171170
self.log.message('Pipeline composition started.')
172171
self.was_optimised = False
173172
best_pipelines = gp_composer.compose_pipeline(data=train_data)
174173
best_pipeline_candidates = gp_composer.best_models
175174
self.was_optimised = True
176175
else:
177-
# Use initial pipeline as final solution
178176
self.log.message(f'Timeout is too small for composing and is skipped '
179177
f'because fit_time is {self.timer.assumption_fit_spend_time.total_seconds()} sec.')
180178
best_pipelines = fitted_assumption
@@ -192,18 +190,23 @@ def tune_final_pipeline(self, train_data: InputData,
192190
""" Launch tuning procedure for obtained pipeline by composer """
193191
timeout_for_tuning = execution_plan.tuning_timeout_minutes if execution_plan else abs(
194192
self.timer.determine_resources_for_tuning()) / 60
193+
tuner_plan = build_tuner_plan(
194+
metrics=self.metrics,
195+
timeout_minutes=timeout_for_tuning,
196+
iterations=DEFAULT_TUNING_ITERATIONS_NUMBER,
197+
)
195198
tuner = (TunerBuilder(self.params.task)
196199
.with_tuner(SimultaneousTuner)
197-
.with_metric(self.metrics[0])
198-
.with_iterations(DEFAULT_TUNING_ITERATIONS_NUMBER)
199-
.with_timeout(datetime.timedelta(minutes=timeout_for_tuning))
200+
.with_metric(tuner_plan.metric)
201+
.with_iterations(tuner_plan.iterations)
202+
.with_timeout(datetime.timedelta(minutes=tuner_plan.timeout_minutes))
200203
.with_eval_time_constraint(self.params.composer_requirements.max_graph_fit_time)
201204
.with_requirements(self.params.composer_requirements)
202205
.build(train_data))
203206

204207
with self.timer.launch_tuning():
205208
self.was_tuned = False
206-
self.log.message(f'Hyperparameters tuning started with {round(timeout_for_tuning)} min. timeout')
209+
self.log.message(f'Hyperparameters tuning started with {round(tuner_plan.timeout_minutes)} min. timeout')
207210
tuned_pipeline = tuner.tune(pipeline_gp_composed)
208211
self.log.message('Hyperparameters tuning finished')
209212
self.was_tuned = tuner.was_tuned
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from dataclasses import dataclass
from typing import Any, Optional, Sequence
3+
4+
5+
@dataclass(frozen=True)
class CacheInitPlan:
    """Resolved decision of which caches to initialize and with what settings.

    ``use_preprocessing_cache`` is expected to already account for the
    input-preprocessing flag (see ``build_cache_init_plan``), so consumers
    can check each field independently.
    """

    use_operations_cache: bool
    use_preprocessing_cache: bool
    use_predictions_cache: bool
    # ``Optional[str]`` instead of ``str | None``: dataclass field annotations
    # are evaluated when the class is created, and the PEP 604 union syntax
    # raises TypeError at runtime on Python < 3.10.
    cache_dir: Optional[str]
    use_stats: bool
12+
13+
14+
@dataclass(frozen=True)
class TunerPlan:
    """Resolved configuration for the final-pipeline tuning step."""

    # First (primary) metric driving the tuner; type depends on the caller's
    # metric representation, hence ``Any``.
    metric: Any
    # Number of tuning iterations to run.
    iterations: int
    # Tuning time budget in minutes, clamped to be non-negative by the builder.
    timeout_minutes: float
19+
20+
21+
def build_cache_init_plan(use_operations_cache: bool,
                          use_preprocessing_cache: bool,
                          use_predictions_cache: bool,
                          use_input_preprocessing: bool,
                          cache_dir: Optional[str],
                          use_stats: bool) -> CacheInitPlan:
    """Normalize raw API flags into a :class:`CacheInitPlan`.

    The preprocessing cache is only enabled when input preprocessing itself
    is enabled, since it caches the results of that preprocessing step.
    All flags are coerced with ``bool()`` so that truthy/falsy params values
    (e.g. ``None`` or ``1``) produce a plan of plain booleans.

    :param use_operations_cache: enable the fitted-operations cache
    :param use_preprocessing_cache: enable the preprocessing cache
    :param use_predictions_cache: enable the predictions cache
    :param use_input_preprocessing: master switch for input preprocessing
    :param cache_dir: directory for cache storage, or None for the default
    :param use_stats: collect cache usage statistics
    :return: immutable plan describing which caches to set up
    """
    return CacheInitPlan(
        use_operations_cache=bool(use_operations_cache),
        # A preprocessing cache is pointless when preprocessing is disabled.
        use_preprocessing_cache=bool(use_input_preprocessing and use_preprocessing_cache),
        use_predictions_cache=bool(use_predictions_cache),
        cache_dir=cache_dir,
        use_stats=bool(use_stats),
    )
34+
35+
36+
def build_tuner_plan(metrics: Sequence[Any], timeout_minutes: float, iterations: int) -> TunerPlan:
    """Derive a deterministic tuning configuration from raw API inputs.

    The first entry of ``metrics`` is taken as the tuning objective, and a
    negative timeout is clamped to zero instead of being propagated.
    """
    primary_metric = metrics[0]
    # Equivalent to max(0.0, timeout_minutes): never hand the tuner a negative budget.
    clamped_timeout = timeout_minutes if timeout_minutes > 0.0 else 0.0
    return TunerPlan(metric=primary_metric, iterations=iterations, timeout_minutes=clamped_timeout)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import fedot.api.api_utils.api_composer as composer_module
2+
from fedot.api.api_utils.api_composer import ApiComposer
3+
4+
5+
class _FakeCache:
6+
def __init__(self, cache_dir=None, use_stats=False):
7+
self.cache_dir = cache_dir
8+
self.use_stats = use_stats
9+
self.was_reset = False
10+
11+
def reset(self):
12+
self.was_reset = True
13+
14+
15+
class _FakeParams(dict):
    # Dict-backed stand-in for the composer's params object: behaves like the
    # params mapping (so ``params.get(...)`` works) while also exposing
    # attributes that the composer presumably reads directly — TODO confirm
    # against ApiComposer's usage.
    timeout = 1
    n_jobs = -1
18+
19+
20+
def test_api_composer_init_cache_uses_typed_cache_plan(monkeypatch):
    """init_cache must follow the typed plan: with input preprocessing off the
    preprocessing cache stays None, while the other two caches are built and reset."""
    for cache_cls_name in ('OperationsCache', 'PreprocessingCache', 'PredictionsCache'):
        monkeypatch.setattr(composer_module, cache_cls_name, _FakeCache)

    params = _FakeParams(
        use_operations_cache=True,
        use_preprocessing_cache=True,
        use_predictions_cache=True,
        use_input_preprocessing=False,
        cache_dir='cache_dir',
        use_stats=True,
    )

    composer = ApiComposer(params, metrics=['f1'])

    assert composer.preprocessing_cache is None
    for built_cache in (composer.operations_cache, composer.predictions_cache):
        assert isinstance(built_cache, _FakeCache)
        assert built_cache.was_reset is True
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from fedot.api.api_utils.api_composer_rules import build_cache_init_plan, build_tuner_plan
2+
3+
4+
def test_build_cache_init_plan_respects_input_preprocessing_boundary():
    """The preprocessing cache must be disabled whenever input preprocessing is off."""
    flags = dict(
        use_operations_cache=True,
        use_preprocessing_cache=True,
        use_predictions_cache=True,
        use_input_preprocessing=False,
        cache_dir='cache',
        use_stats=True,
    )

    plan = build_cache_init_plan(**flags)

    # Only the preprocessing cache is gated by the input-preprocessing switch.
    assert plan.use_preprocessing_cache is False
    assert plan.use_operations_cache is True
    assert plan.use_predictions_cache is True
    assert plan.cache_dir == 'cache'
    assert plan.use_stats is True
19+
20+
21+
def test_build_tuner_plan_is_deterministic_and_clamps_timeout():
    """The first metric wins, iterations pass through, a negative timeout clamps to 0."""
    plan = build_tuner_plan(metrics=['f1', 'roc_auc'], timeout_minutes=-3, iterations=42)

    assert plan.timeout_minutes == 0.0
    assert plan.iterations == 42
    assert plan.metric == 'f1'

0 commit comments

Comments
 (0)