Skip to content

Commit 1e2a3c6

Browse files
committed
extract pipeline node parameter normalization rules
1 parent 0ef2232 commit 1e2a3c6

File tree

4 files changed

+101
-13
lines changed

4 files changed

+101
-13
lines changed

fedot/core/pipelines/node.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
from fedot.core.operations.factory import OperationFactory
1616
from fedot.core.operations.operation import Operation
1717
from fedot.core.operations.operation_parameters import OperationParameters
18+
from fedot.core.pipelines.pipeline_node_rules import (
19+
merge_node_parameters,
20+
normalize_node_parameters,
21+
should_update_node_parameters,
22+
)
1823
from fedot.core.repository.operation_types_repository import OperationTypesRepository
1924
from fedot.core.utils import DEFAULT_PARAMS_STUB, NESTED_PARAMS_LABEL
2025

@@ -124,7 +129,7 @@ def update_params(self):
124129
"""Updates :attr:`custom_params` with changed parameters"""
125130
new_params = self.fitted_operation.get_params()
126131
changed_parameters = new_params.changed_parameters
127-
updated_parameters = {**self.parameters, **changed_parameters}
132+
updated_parameters = merge_node_parameters(self.parameters, changed_parameters)
128133
self.parameters = updated_parameters
129134

130135
@property
@@ -218,9 +223,7 @@ def fit(self,
218223
descriptive_id=self.descriptive_id)
219224

220225
# Update parameters after operation fitting (they can be corrected)
221-
not_atomized_operation = 'atomized' not in self.operation.operation_type
222-
223-
if not_atomized_operation and 'correct_params' in self.operation.metadata.tags:
226+
if should_update_node_parameters(self.operation.operation_type, self.operation.metadata.tags):
224227
self.update_params()
225228
return operation_predict
226229

@@ -357,15 +360,9 @@ def parameters(self, params: dict):
357360
Args:
358361
params: new parameters to be placed instead of existing
359362
"""
360-
if params is not None:
361-
# The check for "default_params" is needed for backward compatibility.
362-
if params == DEFAULT_PARAMS_STUB:
363-
params = {}
364-
# take nested params if they appeared (mostly used for tuning)
365-
if NESTED_PARAMS_LABEL in params:
366-
params = params[NESTED_PARAMS_LABEL]
367-
self._parameters = OperationParameters.from_operation_type(self.operation.operation_type, **params)
368-
self.content['params'] = self._parameters.to_dict()
363+
normalized_params = normalize_node_parameters(params, DEFAULT_PARAMS_STUB, NESTED_PARAMS_LABEL)
364+
self._parameters = OperationParameters.from_operation_type(self.operation.operation_type, **normalized_params)
365+
self.content['params'] = self._parameters.to_dict()
369366

370367
def __str__(self) -> str:
371368
"""Returns ``str`` representation of the node
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Dict, Iterable, Optional
2+
3+
4+
5+
def normalize_node_parameters(params: Optional[dict], default_params_stub, nested_params_label: str) -> Dict:
6+
if params is None:
7+
return {}
8+
if params == default_params_stub:
9+
return {}
10+
if nested_params_label in params:
11+
return dict(params[nested_params_label])
12+
return dict(params)
13+
14+
15+
16+
def merge_node_parameters(current_parameters: Optional[dict], changed_parameters: Optional[dict]) -> Dict:
17+
return {
18+
**dict(current_parameters or {}),
19+
**dict(changed_parameters or {}),
20+
}
21+
22+
23+
24+
def should_update_node_parameters(operation_type: str, operation_tags: Optional[Iterable[str]]) -> bool:
25+
if 'atomized' in operation_type:
26+
return False
27+
return 'correct_params' in set(operation_tags or ())

tests/core/pipelines/test_node.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from fedot.core.operations.operation_parameters import OperationParameters
2+
from fedot.core.pipelines.node import PipelineNode
3+
from fedot.core.utils import DEFAULT_PARAMS_STUB, NESTED_PARAMS_LABEL
4+
5+
6+
class _FittedOperationWithParams:
7+
def __init__(self, params):
8+
self._params = params
9+
10+
def get_params(self):
11+
return self._params
12+
13+
14+
15+
def test_pipeline_node_parameters_setter_normalizes_default_and_nested_params():
16+
default_node = PipelineNode(operation_type='ridge')
17+
nested_node = PipelineNode(operation_type='ridge')
18+
19+
default_node.parameters = DEFAULT_PARAMS_STUB
20+
nested_node.parameters = {NESTED_PARAMS_LABEL: {'alpha': 1.0}}
21+
22+
assert default_node.parameters == {}
23+
assert nested_node.parameters['alpha'] == 1.0
24+
25+
26+
27+
def test_pipeline_node_update_params_uses_typed_merge_rule():
28+
node = PipelineNode(operation_type='ridge')
29+
node.parameters = {'alpha': 1.0}
30+
fitted_params = OperationParameters(alpha=1.0)
31+
fitted_params.update(beta=2.0)
32+
node.fitted_operation = _FittedOperationWithParams(fitted_params)
33+
34+
node.update_params()
35+
36+
assert node.parameters['alpha'] == 1.0
37+
assert node.parameters['beta'] == 2.0
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from fedot.core.pipelines.pipeline_node_rules import (
2+
merge_node_parameters,
3+
normalize_node_parameters,
4+
should_update_node_parameters,
5+
)
6+
from fedot.core.utils import DEFAULT_PARAMS_STUB, NESTED_PARAMS_LABEL
7+
8+
9+
10+
def test_normalize_node_parameters_handles_default_stub_and_nested_params():
11+
assert normalize_node_parameters(DEFAULT_PARAMS_STUB, DEFAULT_PARAMS_STUB, NESTED_PARAMS_LABEL) == {}
12+
assert normalize_node_parameters(
13+
{NESTED_PARAMS_LABEL: {'alpha': 1.0}},
14+
DEFAULT_PARAMS_STUB,
15+
NESTED_PARAMS_LABEL,
16+
) == {'alpha': 1.0}
17+
assert normalize_node_parameters({'beta': 2.0}, DEFAULT_PARAMS_STUB, NESTED_PARAMS_LABEL) == {'beta': 2.0}
18+
19+
20+
21+
def test_merge_node_parameters_and_update_rule_are_explicit():
22+
merged = merge_node_parameters({'alpha': 1.0}, {'beta': 2.0})
23+
24+
assert merged == {'alpha': 1.0, 'beta': 2.0}
25+
assert should_update_node_parameters('ridge', ['correct_params']) is True
26+
assert should_update_node_parameters('atomized_operation', ['correct_params']) is False
27+
assert should_update_node_parameters('ridge', ['linear']) is False

0 commit comments

Comments
 (0)