1- import datetime
2- from dataclasses import asdict
1+ import datetime
32from typing import Sequence
43
54from golem .core .optimisers .genetic .operators .inheritance import GeneticSchemeTypesEnum
65from golem .core .optimisers .genetic .operators .mutation import MutationTypesEnum
76
7+ from fedot .api .api_utils .api_params_repository_rules import apply_default_params , build_default_api_params
88from fedot .api .sampling_stage .config import validate_sampling_config
99from fedot .core .composer .gp_composer .specific_operators import parameter_change_mutation , add_resample_mutation
10- from fedot .core .constants import AUTO_PRESET_NAME
1110from fedot .core .repository .tasks import TaskTypesEnum
1211from fedot .core .utils import default_fedot_data_dir
1312
@@ -33,68 +32,16 @@ def __init__(self, task_type: TaskTypesEnum):
3332 @staticmethod
3433 def default_params_for_task (task_type : TaskTypesEnum ) -> dict :
3534 """ Returns a dict with default parameters"""
36- if task_type in [TaskTypesEnum .classification , TaskTypesEnum .regression ]:
37- cv_folds = 5
38-
39- elif task_type == TaskTypesEnum .ts_forecasting :
40- cv_folds = 3
41-
42- # Dict with allowed keyword attributes for Api and their default values. If None - default value set
43- # in dataclasses ``PipelineComposerRequirements``, ``GPAlgorithmParameters``, ``GraphGenerationParams``
44- # will be used.
45- default_param_values_dict = dict (
46- parallelization_mode = 'populational' ,
47- show_progress = True ,
48- max_depth = 6 ,
49- max_arity = 3 ,
50- pop_size = 20 ,
51- num_of_generations = None ,
52- keep_n_best = 1 ,
53- available_operations = None ,
54- metric = None ,
55- cv_folds = cv_folds ,
56- genetic_scheme = None ,
57- early_stopping_iterations = None ,
58- early_stopping_timeout = 10 ,
59- optimizer = None ,
60- collect_intermediate_metric = False ,
61- max_pipeline_fit_time = None ,
62- initial_assumption = None ,
63- preset = AUTO_PRESET_NAME ,
64- use_operations_cache = True ,
65- use_preprocessing_cache = True ,
66- use_predictions_cache = True ,
67- use_stats = False ,
68- use_input_preprocessing = True ,
69- use_auto_preprocessing = False ,
70- use_meta_rules = False ,
71- cache_dir = default_fedot_data_dir (),
72- keep_history = True ,
73- history_dir = default_fedot_data_dir (),
74- with_tuning = True ,
75- seed = None ,
76- sampling_config = None ,
77- )
78- return default_param_values_dict
35+ return build_default_api_params (task_type , default_fedot_data_dir ())
7936
8037 def check_and_set_default_params (self , params : dict ) -> dict :
8138 """ Sets default values for parameters which were not set by the user
8239 and raises KeyError for invalid parameter keys"""
83- allowed_keys = self .default_params .keys ()
84- invalid_keys = params .keys () - allowed_keys
85- if invalid_keys :
86- raise KeyError (f"Invalid key parameters { invalid_keys } " )
87-
88- if 'sampling_config' in params :
89- validated_sampling_config = validate_sampling_config (params ['sampling_config' ])
90- params ['sampling_config' ] = asdict (validated_sampling_config ) if validated_sampling_config else None
91-
92- missing_params = self .default_params .keys () - params .keys ()
93- for k in missing_params :
94- if (v := self .default_params [k ]) is not None :
95- params [k ] = v
96-
97- return params
40+ return apply_default_params (
41+ params = params ,
42+ default_params = self .default_params ,
43+ sampling_validator = validate_sampling_config ,
44+ )
9845
9946 @staticmethod
10047 def get_params_for_composer_requirements (params : dict ) -> dict :
0 commit comments