Skip to content

Commit e624592

Browse files
committed
fix docs
1 parent e9f3427 commit e624592

File tree

4 files changed

+189
-62
lines changed

4 files changed

+189
-62
lines changed

autointent/_pipeline/_pipeline.py

Lines changed: 177 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
"""Pipeline optimizer module.
2-
3-
This module defines the Pipeline class, which is responsible for optimizing and managing a pipeline of inference nodes.
4-
It provides functionality for configuration, optimization, validation, and inference.
5-
"""
1+
"""Pipeline optimizer."""
62

73
import json
84
import logging
@@ -13,20 +9,24 @@
139
import yaml
1410
from typing_extensions import assert_never
1511

16-
from autointent import Context, Dataset
12+
from autointent import Context, Dataset, OptimizationConfig
1713
from autointent.configs import (
1814
CrossEncoderConfig,
1915
DataConfig,
2016
EmbedderConfig,
17+
InferenceNodeConfig,
2118
LoggingConfig,
2219
)
2320
from autointent.custom_types import (
2421
ListOfGenericLabels,
2522
NodeType,
2623
SamplerType,
24+
SearchSpacePresets,
2725
SearchSpaceValidationMode,
2826
)
27+
from autointent.metrics import DECISION_METRICS
2928
from autointent.nodes import InferenceNode, NodeOptimizer
29+
from autointent.utils import load_preset, load_search_space
3030

3131
from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput
3232

@@ -35,16 +35,7 @@
3535

3636

3737
class Pipeline:
38-
"""Pipeline optimizer for managing and optimizing inference nodes.
39-
40-
This class is responsible for initializing and optimizing a sequence of nodes that perform inference tasks.
41-
It supports loading configurations, validating data, and making predictions.
42-
43-
Attributes:
44-
nodes: Dictionary of node types mapped to their respective objects.
45-
sampler: Sampling method used for optimization.
46-
seed: Random seed for reproducibility.
47-
"""
38+
"""Pipeline optimizer class."""
4839

4940
def __init__(
5041
self,
@@ -55,12 +46,9 @@ def __init__(
5546
"""Initialize the pipeline optimizer.
5647
5748
Args:
58-
nodes: List of nodes to be optimized or used for inference.
59-
sampler: Sampling strategy for optimization. Defaults to "brute".
60-
seed: Random seed for reproducibility. Defaults to 42.
61-
62-
Raises:
63-
ValueError: If the provided sampler type is invalid.
49+
nodes: List of nodes.
50+
sampler: Sampler type.
51+
seed: Random seed.
6452
"""
6553
self._logger = logging.getLogger(__name__)
6654
self.nodes = {node.node_type: node for node in nodes}
@@ -80,7 +68,7 @@ def __init__(
8068
assert_never(nodes)
8169

8270
def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig) -> None:
83-
"""Set configuration for the pipeline.
71+
"""Set the configuration for the pipeline.
8472
8573
Args:
8674
config: Configuration object.
@@ -96,14 +84,103 @@ def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig
9684
else:
9785
assert_never(config)
9886

99-
def fit(self, dataset: Dataset) -> Context:
100-
"""Optimize the pipeline using a dataset.
87+
@classmethod
88+
def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed: int = 42) -> "Pipeline":
89+
"""Create a pipeline optimizer from a search space.
90+
91+
Args:
92+
search_space: Search space.
93+
seed: Random seed.
94+
95+
Returns:
96+
Pipeline optimizer.
97+
"""
98+
if not isinstance(search_space, list):
99+
search_space = load_search_space(search_space)
100+
nodes = [NodeOptimizer(**node) for node in search_space]
101+
return cls(nodes=nodes, seed=seed)
102+
103+
@classmethod
104+
def from_preset(cls, name: SearchSpacePresets, seed: int = 42) -> "Pipeline":
105+
optimization_config = load_preset(name)
106+
config = OptimizationConfig(seed=seed, **optimization_config)
107+
return cls.from_optimization_config(config=config)
108+
109+
@classmethod
110+
def from_optimization_config(cls, config: dict[str, Any] | Path | str | OptimizationConfig) -> "Pipeline":
111+
"""Create pipeline optimizer from optimization config.
112+
113+
:param config: Optimization config
114+
:return: Pipeline optimizer configured from the optimization config.
115+
"""
116+
if isinstance(config, OptimizationConfig):
117+
optimization_config = config
118+
else:
119+
if isinstance(config, dict):
120+
dict_params = config
121+
else:
122+
with Path(config).open() as file:
123+
dict_params = yaml.safe_load(file)
124+
optimization_config = OptimizationConfig(**dict_params)
125+
126+
pipeline = cls(
127+
[NodeOptimizer(**node.model_dump()) for node in optimization_config.search_space],
128+
optimization_config.sampler,
129+
optimization_config.seed,
130+
)
131+
pipeline.set_config(optimization_config.logging_config)
132+
pipeline.set_config(optimization_config.data_config)
133+
pipeline.set_config(optimization_config.embedder_config)
134+
pipeline.set_config(optimization_config.cross_encoder_config)
135+
return pipeline
136+
137+
def _fit(self, context: Context, sampler: SamplerType) -> None:
138+
"""Optimize the pipeline.
139+
140+
Args:
141+
context: Context object.
142+
sampler: Sampler type.
143+
"""
144+
self.context = context
145+
self._logger.info("starting pipeline optimization...")
146+
self.context.callback_handler.start_run(
147+
run_name=self.context.logging_config.get_run_name(),
148+
dirpath=self.context.logging_config.dirpath,
149+
)
150+
for node_type in NodeType:
151+
node_optimizer = self.nodes.get(node_type, None)
152+
if node_optimizer is not None:
153+
node_optimizer.fit(context, sampler) # type: ignore[union-attr]
154+
self.context.callback_handler.end_run()
155+
156+
def _is_inference(self) -> bool:
157+
"""Check whether the pipeline is in inference mode.
158+
159+
Returns:
160+
True if pipeline is in inference mode, False otherwise.
161+
"""
162+
return isinstance(self.nodes[NodeType.scoring], InferenceNode)
163+
164+
def fit(
165+
self,
166+
dataset: Dataset,
167+
refit_after: bool = False,
168+
sampler: SamplerType | None = None,
169+
incompatible_search_space: SearchSpaceValidationMode = "filter",
170+
) -> Context:
171+
"""Optimize the pipeline from dataset.
101172
102173
Args:
103-
dataset: The dataset used for optimization.
174+
dataset: Dataset for optimization.
175+
refit_after: Whether to refit after optimization.
176+
sampler: Sampler type to use.
177+
incompatible_search_space: How to handle incompatible search space.
104178
105179
Returns:
106-
Context: The resulting context after optimization.
180+
Context object.
181+
182+
Raises:
183+
RuntimeError: If pipeline is in inference mode.
107184
"""
108185
if self._is_inference():
109186
msg = "Pipeline in inference mode cannot be fitted"
@@ -115,53 +192,103 @@ def fit(self, dataset: Dataset) -> Context:
115192
context.configure_transformer(self.embedder_config)
116193
context.configure_transformer(self.cross_encoder_config)
117194

118-
self._fit(context, self.sampler)
195+
self.validate_modules(dataset, mode=incompatible_search_space)
196+
197+
test_utterances = context.data_handler.test_utterances()
198+
if test_utterances is None:
199+
self._logger.warning(
200+
"Test data is not provided. Final test metrics won't be calculated after pipeline optimization."
201+
)
202+
203+
if sampler is None:
204+
sampler = self.sampler
205+
206+
self._fit(context, sampler)
207+
208+
if context.is_ram_to_clear():
209+
nodes_configs = context.optimization_info.get_inference_nodes_config()
210+
nodes_list = [InferenceNode.from_config(cfg) for cfg in nodes_configs]
211+
else:
212+
modules_dict = context.optimization_info.get_best_modules()
213+
nodes_list = [InferenceNode(module, node_type) for node_type, module in modules_dict.items()]
214+
215+
self.nodes = {node.node_type: node for node in nodes_list}
216+
217+
if refit_after:
218+
# TODO reflect this refitting in dumped version of pipeline
219+
self._refit(context)
220+
221+
if test_utterances is not None:
222+
predictions = self.predict(test_utterances)
223+
for metric_name, metric in DECISION_METRICS.items():
224+
context.optimization_info.pipeline_metrics[metric_name] = metric(
225+
context.data_handler.test_labels(),
226+
predictions,
227+
)
228+
context.callback_handler.log_final_metrics(context.optimization_info.dump_evaluation_results())
229+
119230
return context
120231

121232
def validate_modules(self, dataset: Dataset, mode: SearchSpaceValidationMode) -> None:
122-
"""Validate nodes against a dataset.
233+
"""Validate modules with dataset.
123234
124235
Args:
125-
dataset: Dataset used for validation.
236+
dataset: Dataset for validation.
126237
mode: Validation mode.
127238
"""
128239
for node in self.nodes.values():
129240
if isinstance(node, NodeOptimizer):
130241
node.validate_nodes_with_dataset(dataset, mode)
131242

132-
def _is_inference(self) -> bool:
133-
"""Check whether the pipeline is in inference mode.
243+
@classmethod
244+
def from_dict_config(cls, nodes_configs: list[dict[str, Any]]) -> "Pipeline":
245+
"""Create inference pipeline from dictionary config.
246+
247+
Args:
248+
nodes_configs: List of configs for the nodes.
134249
135250
Returns:
136-
True if pipeline is in inference mode, otherwise False.
251+
Inference pipeline
137252
"""
138-
return isinstance(self.nodes[NodeType.scoring], InferenceNode)
253+
return cls.from_config([InferenceNodeConfig(**cfg) for cfg in nodes_configs])
254+
255+
@classmethod
256+
def from_config(cls, nodes_configs: list[InferenceNodeConfig]) -> "Pipeline":
257+
"""Create inference pipeline from config.
258+
259+
Args:
260+
nodes_configs: List of configs for the nodes.
261+
262+
Returns:
263+
Inference pipeline
264+
"""
265+
nodes = [InferenceNode.from_config(cfg) for cfg in nodes_configs]
266+
return cls(nodes)
139267

140268
@classmethod
141269
def load(cls, path: str | Path) -> "Pipeline":
142-
"""Load a pipeline from a given directory.
270+
"""Load pipeline in inference mode.
271+
272+
This method loads fitted modules and tuned hyperparameters.
143273
144274
Args:
145-
path: Path to the directory containing the pipeline configuration.
275+
path: Path to the directory that contains the ``inference_config.yaml`` file.
146276
147277
Returns:
148-
Loaded pipeline instance.
278+
Inference pipeline
149279
"""
150280
with (Path(path) / "inference_config.yaml").open() as file:
151281
inference_dict_config = yaml.safe_load(file)
152282
return cls.from_dict_config(inference_dict_config["nodes_configs"])
153283

154284
def predict(self, utterances: list[str]) -> ListOfGenericLabels:
155-
"""Predict labels for a list of utterances.
285+
"""Predict the labels for the utterances.
156286
157287
Args:
158-
utterances: List of utterances to predict labels for.
288+
utterances: list of utterances
159289
160290
Returns:
161-
ListOfGenericLabels: Predicted labels for the utterances.
162-
163-
Raises:
164-
RuntimeError: If the pipeline is not in inference mode.
291+
list of predicted labels
165292
"""
166293
if not self._is_inference():
167294
msg = "Pipeline in optimization mode cannot perform inference"
@@ -177,7 +304,10 @@ def _refit(self, context: Context) -> None:
177304
"""Fit pipeline of already selected modules with all train data.
178305
179306
Args:
180-
context: context object to take data from
307+
context: Context object.
308+
309+
Raises:
310+
RuntimeError: If pipeline is in optimization mode.
181311
"""
182312
if not self._is_inference():
183313
msg = "Pipeline in optimization mode cannot perform inference"
@@ -198,12 +328,8 @@ def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutpu
198328
199329
Args:
200330
utterances: list of utterances
201-
202331
Returns:
203-
InferencePipelineOutput: prediction output
204-
205-
Raises:
206-
RuntimeError: If the pipeline is not in inference mode.
332+
Inference pipeline output
207333
"""
208334
if not self._is_inference():
209335
msg = "Pipeline in optimization mode cannot perform inference"
@@ -242,11 +368,11 @@ def make_report(logs: dict[str, Any], nodes: list[NodeType]) -> str:
242368
"""Generate a report from optimization logs.
243369
244370
Args:
245-
logs: Dictionary containing optimization logs.
371+
logs: Logs dictionary.
246372
nodes: List of node types.
247373
248374
Returns:
249-
Formatted report string.
375+
String report.
250376
"""
251377
ids = [np.argmax(logs["metrics"][node]) for node in nodes]
252378
configs = []

autointent/generation/utterances/balancer.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,15 @@ def __init__(
2424
async_mode: bool = False,
2525
max_samples_per_class: int | None = None,
2626
) -> None:
27-
"""
28-
Initialize the UtteranceBalancer.
27+
"""Initialize the UtteranceBalancer.
2928
3029
Args:
3130
generator (Generator): The generator object used to create utterances.
3231
prompt_maker (Callable[[Intent, int], list[Message]]): A callable that creates prompts for the generator.
33-
seed (int, optional): The seed for random number generation. Defaults to 42.
3432
async_mode (bool, optional): Whether to run the generator in asynchronous mode. Defaults to False.
3533
max_samples_per_class (int | None, optional): The maximum number of samples per class.
3634
Must be a positive integer or None. Defaults to None.
35+
3736
Raises:
3837
ValueError: If max_samples_per_class is not None and is less than or equal to 0.
3938
"""
@@ -47,12 +46,10 @@ def __init__(
4746
self.max_samples = max_samples_per_class
4847

4948
def balance(self, dataset: Dataset, split: str = Split.TRAIN, batch_size: int = 4) -> Dataset:
50-
"""
51-
Balances the specified dataset split.
49+
"""Balances the specified dataset split.
5250
5351
:param dataset: Source dataset
5452
:param split: Target split for balancing
55-
:param n_evolutions: Number of augmentations per example
5653
:param batch_size: Batch size for asynchronous processing
5754
:return: Balanced dataset
5855
"""
@@ -142,7 +139,11 @@ def _augment_class(self, dataset: Dataset, split: str, class_id: int, needed: in
142139
logger.debug("Total samples after augmentation: %s", final_count)
143140

144141
def _process_utterances(self, generated: list[str]) -> list[str]:
145-
"""Process and clean generated utterances."""
142+
"""Process and clean generated utterances.
143+
144+
Args:
145+
generated: List of generated utterances to process.
146+
"""
146147
processed = []
147148
for ut in generated:
148149
if "', '" in ut or "',\n" in ut:

0 commit comments

Comments
 (0)