1212from autointent import Context , Dataset
1313from autointent .configs import EmbedderConfig , InferenceNodeConfig , LoggingConfig , VectorIndexConfig
1414from autointent .custom_types import NodeType
15+ from autointent .metrics import PREDICTION_METRICS_MULTILABEL
1516from autointent .nodes import InferenceNode , NodeOptimizer
1617from autointent .utils import load_default_search_space , load_search_space
1718
@@ -103,7 +104,7 @@ def _is_inference(self) -> bool:
103104 """
104105 return isinstance (self .nodes [NodeType .scoring ], InferenceNode )
105106
106- def fit (self , dataset : Dataset , force_multilabel : bool = False , init_for_inference : bool = True ) -> Context :
107+ def fit (self , dataset : Dataset , force_multilabel : bool = False ) -> Context :
107108 """
108109 Optimize the pipeline from dataset.
109110
@@ -122,15 +123,20 @@ def fit(self, dataset: Dataset, force_multilabel: bool = False, init_for_inferen
122123
123124 self ._fit (context )
124125
125- if init_for_inference :
126- if context .is_ram_to_clear ():
127- nodes_configs = context .optimization_info .get_inference_nodes_config ()
128- nodes_list = [InferenceNode .from_config (cfg ) for cfg in nodes_configs ]
129- else :
130- modules_dict = context .optimization_info .get_best_modules ()
131- nodes_list = [InferenceNode (module , node_type ) for node_type , module in modules_dict .items ()]
126+ if context .is_ram_to_clear ():
127+ nodes_configs = context .optimization_info .get_inference_nodes_config ()
128+ nodes_list = [InferenceNode .from_config (cfg ) for cfg in nodes_configs ]
129+ else :
130+ modules_dict = context .optimization_info .get_best_modules ()
131+ nodes_list = [InferenceNode (module , node_type ) for node_type , module in modules_dict .items ()]
132+
133+ self .nodes = {node .node_type : node for node in nodes_list }
132134
133- self .nodes = {node .node_type : node for node in nodes_list }
135+ predictions = self .predict (context .data_handler .test_utterances ())
136+ for metric_name , metric in PREDICTION_METRICS_MULTILABEL .items ():
137+ context .optimization_info .pipeline_metrics [metric_name ] = metric (
138+ context .data_handler .test_labels (), predictions ,
139+ )
134140
135141 return context
136142
0 commit comments