Skip to content

Commit 5b5eabd

Browse files
committed
extract pure api params repository defaulting rules
1 parent bf09df6 commit 5b5eabd

File tree

1 file changed

+8
-61
lines changed

1 file changed

+8
-61
lines changed

fedot/api/api_utils/api_params_repository.py

Lines changed: 8 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
import datetime
2-
from dataclasses import asdict
1+
import datetime
32
from typing import Sequence
43

54
from golem.core.optimisers.genetic.operators.inheritance import GeneticSchemeTypesEnum
65
from golem.core.optimisers.genetic.operators.mutation import MutationTypesEnum
76

7+
from fedot.api.api_utils.api_params_repository_rules import apply_default_params, build_default_api_params
88
from fedot.api.sampling_stage.config import validate_sampling_config
99
from fedot.core.composer.gp_composer.specific_operators import parameter_change_mutation, add_resample_mutation
10-
from fedot.core.constants import AUTO_PRESET_NAME
1110
from fedot.core.repository.tasks import TaskTypesEnum
1211
from fedot.core.utils import default_fedot_data_dir
1312

@@ -33,68 +32,16 @@ def __init__(self, task_type: TaskTypesEnum):
3332
@staticmethod
3433
def default_params_for_task(task_type: TaskTypesEnum) -> dict:
3534
""" Returns a dict with default parameters"""
36-
if task_type in [TaskTypesEnum.classification, TaskTypesEnum.regression]:
37-
cv_folds = 5
38-
39-
elif task_type == TaskTypesEnum.ts_forecasting:
40-
cv_folds = 3
41-
42-
# Dict with allowed keyword attributes for Api and their default values. If None - default value set
43-
# in dataclasses ``PipelineComposerRequirements``, ``GPAlgorithmParameters``, ``GraphGenerationParams``
44-
# will be used.
45-
default_param_values_dict = dict(
46-
parallelization_mode='populational',
47-
show_progress=True,
48-
max_depth=6,
49-
max_arity=3,
50-
pop_size=20,
51-
num_of_generations=None,
52-
keep_n_best=1,
53-
available_operations=None,
54-
metric=None,
55-
cv_folds=cv_folds,
56-
genetic_scheme=None,
57-
early_stopping_iterations=None,
58-
early_stopping_timeout=10,
59-
optimizer=None,
60-
collect_intermediate_metric=False,
61-
max_pipeline_fit_time=None,
62-
initial_assumption=None,
63-
preset=AUTO_PRESET_NAME,
64-
use_operations_cache=True,
65-
use_preprocessing_cache=True,
66-
use_predictions_cache=True,
67-
use_stats=False,
68-
use_input_preprocessing=True,
69-
use_auto_preprocessing=False,
70-
use_meta_rules=False,
71-
cache_dir=default_fedot_data_dir(),
72-
keep_history=True,
73-
history_dir=default_fedot_data_dir(),
74-
with_tuning=True,
75-
seed=None,
76-
sampling_config=None,
77-
)
78-
return default_param_values_dict
35+
return build_default_api_params(task_type, default_fedot_data_dir())
7936

8037
def check_and_set_default_params(self, params: dict) -> dict:
8138
""" Sets default values for parameters which were not set by the user
8239
and raises KeyError for invalid parameter keys"""
83-
allowed_keys = self.default_params.keys()
84-
invalid_keys = params.keys() - allowed_keys
85-
if invalid_keys:
86-
raise KeyError(f"Invalid key parameters {invalid_keys}")
87-
88-
if 'sampling_config' in params:
89-
validated_sampling_config = validate_sampling_config(params['sampling_config'])
90-
params['sampling_config'] = asdict(validated_sampling_config) if validated_sampling_config else None
91-
92-
missing_params = self.default_params.keys() - params.keys()
93-
for k in missing_params:
94-
if (v := self.default_params[k]) is not None:
95-
params[k] = v
96-
97-
return params
40+
return apply_default_params(
41+
params=params,
42+
default_params=self.default_params,
43+
sampling_validator=validate_sampling_config,
44+
)
9845

9946
@staticmethod
10047
def get_params_for_composer_requirements(params: dict) -> dict:

0 commit comments

Comments
 (0)