1- """Pipeline optimizer module.
2-
3- This module defines the Pipeline class, which is responsible for optimizing and managing a pipeline of inference nodes.
4- It provides functionality for configuration, optimization, validation, and inference.
5- """
1+ """Pipeline optimizer."""
62
73import json
84import logging
139import yaml
1410from typing_extensions import assert_never
1511
16- from autointent import Context , Dataset
12+ from autointent import Context , Dataset , OptimizationConfig
1713from autointent .configs import (
1814 CrossEncoderConfig ,
1915 DataConfig ,
2016 EmbedderConfig ,
17+ InferenceNodeConfig ,
2118 LoggingConfig ,
2219)
2320from autointent .custom_types import (
2421 ListOfGenericLabels ,
2522 NodeType ,
2623 SamplerType ,
24+ SearchSpacePresets ,
2725 SearchSpaceValidationMode ,
2826)
27+ from autointent .metrics import DECISION_METRICS
2928from autointent .nodes import InferenceNode , NodeOptimizer
29+ from autointent .utils import load_preset , load_search_space
3030
3131from ._schemas import InferencePipelineOutput , InferencePipelineUtteranceOutput
3232
3535
3636
 class Pipeline:
-    """Pipeline optimizer for managing and optimizing inference nodes.
-
-    This class is responsible for initializing and optimizing a sequence of nodes that perform inference tasks.
-    It supports loading configurations, validating data, and making predictions.
-
-    Attributes:
-        nodes: Dictionary of node types mapped to their respective objects.
-        sampler: Sampling method used for optimization.
-        seed: Random seed for reproducibility.
-    """
+    """Pipeline optimizer class."""

     def __init__(
         self,
@@ -55,12 +46,9 @@ def __init__(
5546 """Initialize the pipeline optimizer.
5647
5748 Args:
58- nodes: List of nodes to be optimized or used for inference.
59- sampler: Sampling strategy for optimization. Defaults to "brute".
60- seed: Random seed for reproducibility. Defaults to 42.
61-
62- Raises:
63- ValueError: If the provided sampler type is invalid.
49+ nodes: List of nodes.
50+ sampler: Sampler type.
51+ seed: Random seed.
6452 """
6553 self ._logger = logging .getLogger (__name__ )
6654 self .nodes = {node .node_type : node for node in nodes }
@@ -80,7 +68,7 @@ def __init__(
             assert_never(nodes)

     def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig) -> None:
-        """Set configuration for the pipeline.
+        """Set the configuration for the pipeline.

         Args:
             config: Configuration object.
@@ -96,14 +84,103 @@ def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig
         else:
             assert_never(config)

-    def fit(self, dataset: Dataset) -> Context:
-        """Optimize the pipeline using a dataset.
+    @classmethod
+    def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed: int = 42) -> "Pipeline":
+        """Create a pipeline optimizer from a search space.
+
+        Args:
+            search_space: Search space as a list of node configurations or a path to a search space file.
+            seed: Random seed for reproducibility.
+
+        Returns:
+            Pipeline optimizer.
+        """
+        if not isinstance(search_space, list):
+            search_space = load_search_space(search_space)
+        nodes = [NodeOptimizer(**node) for node in search_space]
+        return cls(nodes=nodes, seed=seed)
+
+    @classmethod
+    def from_preset(cls, name: SearchSpacePresets, seed: int = 42) -> "Pipeline":
+        """Create a pipeline optimizer from a search space preset."""
+        optimization_config = load_preset(name)
+        config = OptimizationConfig(seed=seed, **optimization_config)
+        return cls.from_optimization_config(config=config)
+
+    @classmethod
+    def from_optimization_config(cls, config: dict[str, Any] | Path | str | OptimizationConfig) -> "Pipeline":
+        """Create a pipeline optimizer from an optimization config.
+
+        Args:
+            config: Optimization config as an object, dict, or path to a YAML file.
+        """
+        if isinstance(config, OptimizationConfig):
+            optimization_config = config
+        else:
+            if isinstance(config, dict):
+                dict_params = config
+            else:
+                with Path(config).open() as file:
+                    dict_params = yaml.safe_load(file)
+            optimization_config = OptimizationConfig(**dict_params)
+
+        pipeline = cls(
+            [NodeOptimizer(**node.model_dump()) for node in optimization_config.search_space],
+            optimization_config.sampler,
+            optimization_config.seed,
+        )
+        pipeline.set_config(optimization_config.logging_config)
+        pipeline.set_config(optimization_config.data_config)
+        pipeline.set_config(optimization_config.embedder_config)
+        pipeline.set_config(optimization_config.cross_encoder_config)
+        return pipeline
+
+    def _fit(self, context: Context, sampler: SamplerType) -> None:
+        """Optimize the pipeline.
+
+        Args:
+            context: Context object.
+            sampler: Sampling strategy to use.
+        """
+        self.context = context
+        self._logger.info("starting pipeline optimization...")
+        self.context.callback_handler.start_run(
+            run_name=self.context.logging_config.get_run_name(),
+            dirpath=self.context.logging_config.dirpath,
+        )
+        for node_type in NodeType:
+            node_optimizer = self.nodes.get(node_type, None)
+            if node_optimizer is not None:
+                node_optimizer.fit(context, sampler)  # type: ignore[union-attr]
+        self.context.callback_handler.end_run()
+
+    def _is_inference(self) -> bool:
+        """Check whether the pipeline is in inference mode.
+
+        Returns:
+            True if the pipeline is in inference mode, False otherwise.
+        """
+        return isinstance(self.nodes[NodeType.scoring], InferenceNode)
+
+    def fit(
+        self,
+        dataset: Dataset,
+        refit_after: bool = False,
+        sampler: SamplerType | None = None,
+        incompatible_search_space: SearchSpaceValidationMode = "filter",
+    ) -> Context:
+        """Optimize the pipeline on a dataset.

         Args:
-            dataset: The dataset used for optimization.
+            dataset: Dataset for optimization.
+            refit_after: Whether to refit the selected modules on all training data after optimization.
+            sampler: Sampling strategy to use; defaults to the one set at initialization.
+            incompatible_search_space: How to handle search space entries that are incompatible with the dataset.

         Returns:
-            Context: The resulting context after optimization.
+            Context object with optimization results.
+
+        Raises:
+            RuntimeError: If the pipeline is in inference mode.
         """
         if self._is_inference():
             msg = "Pipeline in inference mode cannot be fitted"
@@ -115,53 +192,103 @@ def fit(self, dataset: Dataset) -> Context:
         context.configure_transformer(self.embedder_config)
         context.configure_transformer(self.cross_encoder_config)

-        self._fit(context, self.sampler)
+        self.validate_modules(dataset, mode=incompatible_search_space)
+
+        test_utterances = context.data_handler.test_utterances()
+        if test_utterances is None:
+            self._logger.warning(
+                "Test data is not provided. Final test metrics won't be calculated after pipeline optimization."
+            )
+
+        if sampler is None:
+            sampler = self.sampler
+
+        self._fit(context, sampler)
+
+        if context.is_ram_to_clear():
+            nodes_configs = context.optimization_info.get_inference_nodes_config()
+            nodes_list = [InferenceNode.from_config(cfg) for cfg in nodes_configs]
+        else:
+            modules_dict = context.optimization_info.get_best_modules()
+            nodes_list = [InferenceNode(module, node_type) for node_type, module in modules_dict.items()]
+
+        self.nodes = {node.node_type: node for node in nodes_list}
+
+        if refit_after:
+            # TODO reflect this refitting in dumped version of pipeline
+            self._refit(context)
+
+        if test_utterances is not None:
+            predictions = self.predict(test_utterances)
+            for metric_name, metric in DECISION_METRICS.items():
+                context.optimization_info.pipeline_metrics[metric_name] = metric(
+                    context.data_handler.test_labels(),
+                    predictions,
+                )
+        context.callback_handler.log_final_metrics(context.optimization_info.dump_evaluation_results())
+
         return context

     def validate_modules(self, dataset: Dataset, mode: SearchSpaceValidationMode) -> None:
-        """Validate nodes against a dataset.
+        """Validate the search space modules against a dataset.

         Args:
-            dataset: Dataset used for validation.
+            dataset: Dataset for validation.
             mode: Validation mode.
         """
         for node in self.nodes.values():
             if isinstance(node, NodeOptimizer):
                 node.validate_nodes_with_dataset(dataset, mode)

-    def _is_inference(self) -> bool:
-        """Check whether the pipeline is in inference mode.
+    @classmethod
+    def from_dict_config(cls, nodes_configs: list[dict[str, Any]]) -> "Pipeline":
+        """Create an inference pipeline from a list of dictionary node configs.
+
+        Args:
+            nodes_configs: List of node configurations as dictionaries.

         Returns:
-            True if pipeline is in inference mode, otherwise False.
+            Inference pipeline.
         """
-        return isinstance(self.nodes[NodeType.scoring], InferenceNode)
+        return cls.from_config([InferenceNodeConfig(**cfg) for cfg in nodes_configs])
+
+    @classmethod
+    def from_config(cls, nodes_configs: list[InferenceNodeConfig]) -> "Pipeline":
+        """Create an inference pipeline from a list of node configs.
+
+        Args:
+            nodes_configs: List of node configurations.
+
+        Returns:
+            Inference pipeline.
+        """
+        nodes = [InferenceNode.from_config(cfg) for cfg in nodes_configs]
+        return cls(nodes)

     @classmethod
     def load(cls, path: str | Path) -> "Pipeline":
-        """Load a pipeline from a given directory.
+        """Load a pipeline in inference mode.
+
+        This method loads fitted modules and tuned hyperparameters.

         Args:
-            path: Path to the directory containing the pipeline configuration.
+            path: Path to the directory containing the dumped pipeline.

         Returns:
-            Loaded pipeline instance.
+            Inference pipeline.
         """
         with (Path(path) / "inference_config.yaml").open() as file:
             inference_dict_config = yaml.safe_load(file)
         return cls.from_dict_config(inference_dict_config["nodes_configs"])

     def predict(self, utterances: list[str]) -> ListOfGenericLabels:
-        """Predict labels for a list of utterances.
+        """Predict labels for the given utterances.

         Args:
-            utterances: List of utterances to predict labels for.
+            utterances: List of utterances.

         Returns:
-            ListOfGenericLabels: Predicted labels for the utterances.
-
-        Raises:
-            RuntimeError: If the pipeline is not in inference mode.
+            List of predicted labels.
         """
         if not self._is_inference():
             msg = "Pipeline in optimization mode cannot perform inference"
@@ -177,7 +304,10 @@ def _refit(self, context: Context) -> None:
177304 """Fit pipeline of already selected modules with all train data.
178305
179306 Args:
180- context: context object to take data from
307+ context: Context object.
308+
309+ Raises:
310+ RuntimeError: If pipeline is in optimization mode.
181311 """
182312 if not self ._is_inference ():
183313 msg = "Pipeline in optimization mode cannot perform inference"
@@ -198,12 +328,8 @@ def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutpu

         Args:
             utterances: list of utterances
-
         Returns:
-            InferencePipelineOutput: prediction output
-
-        Raises:
-            RuntimeError: If the pipeline is not in inference mode.
+            Inference pipeline output.
         """
         if not self._is_inference():
             msg = "Pipeline in optimization mode cannot perform inference"
@@ -242,11 +368,11 @@ def make_report(logs: dict[str, Any], nodes: list[NodeType]) -> str:
242368 """Generate a report from optimization logs.
243369
244370 Args:
245- logs: Dictionary containing optimization logs .
371+ logs: Logs dictionary .
246372 nodes: List of node types.
247373
248374 Returns:
249- Formatted report string .
375+ String report.
250376 """
251377 ids = [np .argmax (logs ["metrics" ][node ]) for node in nodes ]
252378 configs = []
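For the inference-mode entry points touched in this commit, a sketch assuming a pipeline was previously dumped to a directory (the path and utterances are illustrative):

    from autointent import Pipeline

    # load() reads inference_config.yaml from the given directory and restores
    # fitted modules together with their tuned hyperparameters.
    pipeline = Pipeline.load("runs/best_pipeline")

    labels = pipeline.predict(["i want to book a flight"])

    # predict_with_metadata() returns an InferencePipelineOutput with per-utterance details.
    output = pipeline.predict_with_metadata(["i want to book a flight"])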