1919class NodeOptimizer :
2020 """Node optimizer class."""
2121
22- def __init__ (self , node_type : NodeType , search_space : list [dict [str , Any ]], metric : str ) -> None :
22+ def __init__ (
23+ self ,
24+ node_type : NodeType ,
25+ search_space : list [dict [str , Any ]],
26+ target_metric : str ,
27+ metrics : list [str ] | None = None ,
28+ ) -> None :
2329 """
2430 Initialize the node optimizer.
2531
@@ -29,7 +35,12 @@ def __init__(self, node_type: NodeType, search_space: list[dict[str, Any]], metr
2935 """
3036 self .node_type = node_type
3137 self .node_info = NODES_INFO [node_type ]
32- self .metric_name = metric
38+ self .decision_metric_name = target_metric
39+
40+ self .metrics = metrics if metrics is not None else []
41+ if self .decision_metric_name not in self .metrics :
42+ self .metrics .append (self .decision_metric_name )
43+
3344 self .modules_search_spaces = search_space # TODO search space validation
3445 self ._logger = logging .getLogger (__name__ ) # TODO solve duplicate logging messages problem
3546
@@ -61,14 +72,10 @@ def fit(self, context: Context) -> None:
6172 self .module_fit (module , context )
6273
6374 self ._logger .debug ("scoring %s module..." , module_name )
64- metrics = module .score (context , "validation" )
65- metric_value = metrics [self .metric_name ]
66-
67- # some metrics can produce error. When main metric produces error raise it.
68- if isinstance (metric_value , str ):
69- raise Exception (metric_value ) # noqa: TRY004, TRY002
75+ metrics_score = module .score (context , "validation" , self .metrics )
76+ metric_value = metrics_score [self .decision_metric_name ]
7077
71- context .callback_handler .log_metrics (metrics )
78+ context .callback_handler .log_metrics (metrics_score )
7279 context .callback_handler .end_module ()
7380
7481 dump_dir = context .get_dump_dir ()
@@ -84,7 +91,7 @@ def fit(self, context: Context) -> None:
8491 module_name ,
8592 module_kwargs ,
8693 metric_value ,
87- self .metric_name ,
94+ self .decision_metric_name ,
8895 module .get_assets (), # retriever name / scores / predictions
8996 module_dump_dir ,
9097 module = module if not context .is_ram_to_clear () else None ,
0 commit comments