Skip to content

Commit cd01cdc

Browse files
authored
Add final pipeline metrics to logs (#83)
Add final pipeline metrics to logs
1 parent 0c1a8f9 commit cd01cdc

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

autointent/_pipeline/_pipeline.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from autointent import Context, Dataset
1313
from autointent.configs import EmbedderConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
1414
from autointent.custom_types import NodeType
15+
from autointent.metrics import PREDICTION_METRICS_MULTILABEL
1516
from autointent.nodes import InferenceNode, NodeOptimizer
1617
from autointent.utils import load_default_search_space, load_search_space
1718

@@ -103,7 +104,7 @@ def _is_inference(self) -> bool:
103104
"""
104105
return isinstance(self.nodes[NodeType.scoring], InferenceNode)
105106

106-
def fit(self, dataset: Dataset, force_multilabel: bool = False, init_for_inference: bool = True) -> Context:
107+
def fit(self, dataset: Dataset, force_multilabel: bool = False) -> Context:
107108
"""
108109
Optimize the pipeline from dataset.
109110
@@ -122,15 +123,20 @@ def fit(self, dataset: Dataset, force_multilabel: bool = False, init_for_inferen
122123

123124
self._fit(context)
124125

125-
if init_for_inference:
126-
if context.is_ram_to_clear():
127-
nodes_configs = context.optimization_info.get_inference_nodes_config()
128-
nodes_list = [InferenceNode.from_config(cfg) for cfg in nodes_configs]
129-
else:
130-
modules_dict = context.optimization_info.get_best_modules()
131-
nodes_list = [InferenceNode(module, node_type) for node_type, module in modules_dict.items()]
126+
if context.is_ram_to_clear():
127+
nodes_configs = context.optimization_info.get_inference_nodes_config()
128+
nodes_list = [InferenceNode.from_config(cfg) for cfg in nodes_configs]
129+
else:
130+
modules_dict = context.optimization_info.get_best_modules()
131+
nodes_list = [InferenceNode(module, node_type) for node_type, module in modules_dict.items()]
132+
133+
self.nodes = {node.node_type: node for node in nodes_list}
132134

133-
self.nodes = {node.node_type: node for node in nodes_list}
135+
predictions = self.predict(context.data_handler.test_utterances())
136+
for metric_name, metric in PREDICTION_METRICS_MULTILABEL.items():
137+
context.optimization_info.pipeline_metrics[metric_name] = metric(
138+
context.data_handler.test_labels(), predictions,
139+
)
134140

135141
return context
136142

autointent/context/optimization_info/_optimization_info.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(self) -> None:
6565
self.trials = Trials()
6666
self._trials_best_ids = TrialsIds()
6767
self.modules = ModulesList()
68+
self.pipeline_metrics: dict[str, float] = {}
6869

6970
def log_module_optimization(
7071
self,
@@ -196,6 +197,7 @@ def dump_evaluation_results(self) -> dict[str, Any]:
196197
"""
197198
node_wise_metrics = {node_type: self._get_metrics_values(node_type) for node_type in NodeType}
198199
return {
200+
"pipeline_metrics": self.pipeline_metrics,
199201
"metrics": node_wise_metrics,
200202
"configs": self.trials.model_dump(),
201203
}

tests/pipeline/test_optimization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_no_context_optimization(dataset, task_type):
4444
pipeline_optimizer.set_config(VectorIndexConfig(db_dir=Path(db_dir).resolve()))
4545
pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32, device="cpu"))
4646

47-
context = pipeline_optimizer.fit(dataset, force_multilabel=(task_type == "multilabel"), init_for_inference=False)
47+
context = pipeline_optimizer.fit(dataset, force_multilabel=(task_type == "multilabel"))
4848
context.dump()
4949

5050

@@ -62,7 +62,7 @@ def test_save_db(dataset, task_type):
6262
pipeline_optimizer.set_config(VectorIndexConfig(db_dir=Path(db_dir).resolve(), save_db=True))
6363
pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32, device="cpu"))
6464

65-
context = pipeline_optimizer.fit(dataset, force_multilabel=(task_type == "multilabel"), init_for_inference=False)
65+
context = pipeline_optimizer.fit(dataset, force_multilabel=(task_type == "multilabel"))
6666
context.dump()
6767

6868
assert os.listdir(db_dir)
@@ -82,7 +82,7 @@ def test_dump_modules(dataset, task_type):
8282
pipeline_optimizer.set_config(VectorIndexConfig(db_dir=Path(db_dir).resolve()))
8383
pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32, device="cpu"))
8484

85-
context = pipeline_optimizer.fit(dataset, force_multilabel=(task_type == "multilabel"), init_for_inference=False)
85+
context = pipeline_optimizer.fit(dataset, force_multilabel=(task_type == "multilabel"))
8686
context.dump()
8787

8888
assert os.listdir(dump_dir)

0 commit comments

Comments
 (0)