Skip to content

Commit 409f279

Browse files
authored
Refactor/only one pipeline should live (#62)
1 parent 53d1606 commit 409f279

39 files changed

+570
-746
lines changed

autointent/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from ._embedder import Embedder
22
from .context import Context
33
from .context.data_handler import Dataset
4-
from .pipeline import InferencePipeline, PipelineOptimizer
4+
from .pipeline import Pipeline
55

6-
__all__ = ["Context", "Dataset", "Embedder", "InferencePipeline", "PipelineOptimizer"]
6+
__all__ = ["Context", "Dataset", "Embedder", "Pipeline"]

autointent/configs/__init__.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from ._inference_cli import InferenceConfig
2-
from ._inference_pipeline import InferencePipelineConfig
31
from ._node import InferenceNodeConfig, NodeOptimizerConfig
42
from ._optimization_cli import (
53
AugmentationConfig,
@@ -10,19 +8,15 @@
108
TaskConfig,
119
VectorIndexConfig,
1210
)
13-
from ._pipeline_optimizer import PipelineOptimizerConfig
1411

1512
__all__ = [
1613
"AugmentationConfig",
1714
"DataConfig",
1815
"EmbedderConfig",
19-
"InferenceConfig",
2016
"InferenceNodeConfig",
21-
"InferencePipelineConfig",
2217
"LoggingConfig",
2318
"NodeOptimizerConfig",
2419
"OptimizationConfig",
25-
"PipelineOptimizerConfig",
2620
"TaskConfig",
2721
"VectorIndexConfig",
2822
]

autointent/configs/_inference_cli.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

autointent/configs/_inference_pipeline.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

autointent/configs/_node.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55

66
from omegaconf import MISSING
77

8-
from autointent.custom_types import NodeType, NodeTypeType
8+
from autointent.custom_types import NodeType
99

1010

1111
@dataclass
1212
class InferenceNodeConfig:
1313
"""Configuration for the inference node."""
1414

15-
node_type: NodeTypeType = MISSING
15+
node_type: NodeType = MISSING
1616
"""Type of the node. Should be one of the NODE_TYPES"""
1717
module_type: str = MISSING # TODO: add custom type
1818
"""Type of the module. Should be one of the Module"""

autointent/configs/_pipeline_optimizer.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

autointent/context/optimization_info/_optimization_info.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from numpy.typing import NDArray
1212

1313
from autointent.configs import InferenceNodeConfig
14-
from autointent.custom_types import NODE_TYPES, NodeType, NodeTypeType
14+
from autointent.custom_types import NodeType
1515

1616
from ._data_models import Artifact, Artifacts, RetrieverArtifact, ScorerArtifact, Trial, Trials, TrialsIds
1717
from ._logger import get_logger
@@ -188,13 +188,13 @@ def get_best_oos_scores(self, split: Literal["train", "validation", "test"]) ->
188188
return best_scorer_artifact.oos_scores[split]
189189
return best_scorer_artifact.oos_scores
190190

191-
def dump_evaluation_results(self) -> dict[str, dict[str, list[float]]]:
191+
def dump_evaluation_results(self) -> dict[str, Any]:
192192
"""
193193
Dump evaluation results for all nodes.
194194
195195
:return: Dictionary containing metrics and configurations for all nodes.
196196
"""
197-
node_wise_metrics = {node_type: self._get_metrics_values(node_type) for node_type in NODE_TYPES}
197+
node_wise_metrics = {node_type: self._get_metrics_values(node_type) for node_type in NodeType}
198198
return {
199199
"metrics": node_wise_metrics,
200200
"configs": self.trials.model_dump(),
@@ -206,15 +206,15 @@ def get_inference_nodes_config(self) -> list[InferenceNodeConfig]:
206206
207207
:return: List of `InferenceNodeConfig` objects for inference nodes.
208208
"""
209-
trial_ids = [self._get_best_trial_idx(node_type) for node_type in NODE_TYPES]
209+
trial_ids = [self._get_best_trial_idx(node_type) for node_type in NodeType]
210210
res = []
211-
for idx, node_type in zip(trial_ids, NODE_TYPES, strict=True):
211+
for idx, node_type in zip(trial_ids, NodeType, strict=True):
212212
if idx is None:
213213
continue
214214
trial = self.trials.get_trial(node_type, idx)
215215
res.append(
216216
InferenceNodeConfig(
217-
node_type=node_type, # type: ignore[arg-type]
217+
node_type=node_type,
218218
module_type=trial.module_type,
219219
module_config=trial.module_params,
220220
load_path=trial.module_dump_dir,
@@ -234,11 +234,11 @@ def _get_best_module(self, node_type: str) -> "Module | None":
234234
return self.modules.get(node_type)[idx]
235235
return None
236236

237-
def get_best_modules(self) -> dict[NodeTypeType, "Module"]:
237+
def get_best_modules(self) -> dict[NodeType, "Module"]:
238238
"""
239239
Retrieve the best modules for all node types.
240240
241241
:return: Dictionary of the best modules for each node type.
242242
"""
243-
res = {nt: self._get_best_module(nt) for nt in NODE_TYPES}
244-
return {nt: m for nt, m in res.items() if m is not None} # type: ignore[misc]
243+
res = {nt: self._get_best_module(nt) for nt in NodeType}
244+
return {nt: m for nt, m in res.items() if m is not None}

autointent/custom_types.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,7 @@ class BaseMetadataDict(TypedDict):
5252
class NodeType(str, Enum):
5353
"""Enumeration of node types in the AutoIntent pipeline."""
5454

55+
regexp = "regexp"
5556
retrieval = "retrieval"
56-
prediction = "prediction"
5757
scoring = "scoring"
58-
regexp = "regexp"
59-
60-
61-
NODE_TYPES = [NodeType.retrieval.value, NodeType.prediction.value, NodeType.scoring.value, NodeType.regexp.value]
62-
NodeTypeType = Literal["retrieval", "prediction", "scoring", "regexp"]
58+
prediction = "prediction"
Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,27 @@
11
# TODO: make up a better and more versatile config
2-
nodes:
3-
- node_type: retrieval
4-
metric: retrieval_hit_rate
5-
search_space:
6-
- module_type: vector_db
7-
k: [10]
8-
embedder_name:
9-
- avsolatorio/GIST-small-Embedding-v0
10-
- infgrad/stella-base-en-v2
11-
- node_type: scoring
12-
metric: scoring_roc_auc
13-
search_space:
14-
- module_type: knn
15-
k: [1, 3, 5, 10]
16-
weights: ["uniform", "distance", "closest"]
17-
- module_type: linear
18-
- module_type: dnnc
19-
cross_encoder_name:
20-
- BAAI/bge-reranker-base
21-
- cross-encoder/ms-marco-MiniLM-L-6-v2
22-
k: [1, 3, 5, 10]
23-
- node_type: prediction
24-
metric: prediction_accuracy
25-
search_space:
26-
- module_type: threshold
27-
thresh: [0.5]
28-
- module_type: argmax
2+
- node_type: retrieval
3+
metric: retrieval_hit_rate
4+
search_space:
5+
- module_type: vector_db
6+
k: [10]
7+
embedder_name:
8+
- avsolatorio/GIST-small-Embedding-v0
9+
- infgrad/stella-base-en-v2
10+
- node_type: scoring
11+
metric: scoring_roc_auc
12+
search_space:
13+
- module_type: knn
14+
k: [1, 3, 5, 10]
15+
weights: ["uniform", "distance", "closest"]
16+
- module_type: linear
17+
- module_type: dnnc
18+
cross_encoder_name:
19+
- BAAI/bge-reranker-base
20+
- cross-encoder/ms-marco-MiniLM-L-6-v2
21+
k: [1, 3, 5, 10]
22+
- node_type: prediction
23+
metric: prediction_accuracy
24+
search_space:
25+
- module_type: threshold
26+
thresh: [0.5]
27+
- module_type: argmax
Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,21 @@
11
# TODO: make up a better and more versatile config
2-
nodes:
3-
- node_type: retrieval
4-
metric: retrieval_hit_rate_intersecting
5-
search_space:
6-
- module_type: vector_db
7-
k: [10]
8-
embedder_name:
9-
- deepvk/USER-bge-m3
10-
- node_type: scoring
11-
metric: scoring_roc_auc
12-
search_space:
13-
- module_type: knn
14-
k: [3]
15-
weights: ["uniform", "distance", "closest"]
16-
- module_type: linear
17-
- node_type: prediction
18-
metric: prediction_accuracy
19-
search_space:
20-
- module_type: threshold
21-
thresh: [0.5]
22-
- module_type: adaptive
2+
- node_type: retrieval
3+
metric: retrieval_hit_rate_intersecting
4+
search_space:
5+
- module_type: vector_db
6+
k: [10]
7+
embedder_name:
8+
- deepvk/USER-bge-m3
9+
- node_type: scoring
10+
metric: scoring_roc_auc
11+
search_space:
12+
- module_type: knn
13+
k: [3]
14+
weights: ["uniform", "distance", "closest"]
15+
- module_type: linear
16+
- node_type: prediction
17+
metric: prediction_accuracy
18+
search_space:
19+
- module_type: threshold
20+
thresh: [0.5]
21+
- module_type: adaptive

0 commit comments

Comments
 (0)