Skip to content

Commit 738cdb1

Browse files
committed
change usage of vector index
1 parent c95519c commit 738cdb1

File tree

10 files changed

+60
-32
lines changed

10 files changed

+60
-32
lines changed

autointent/_pipeline/_pipeline.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,10 @@
1818
HPOConfig,
1919
InferenceNodeConfig,
2020
LoggingConfig,
21+
VectorIndexConfig,
22+
get_default_vector_index_config,
2123
)
22-
from autointent.custom_types import (
23-
ListOfGenericLabels,
24-
NodeType,
25-
SearchSpacePreset,
26-
SearchSpaceValidationMode,
27-
)
24+
from autointent.custom_types import ListOfGenericLabels, NodeType, SearchSpacePreset, SearchSpaceValidationMode
2825
from autointent.metrics import DECISION_METRICS, DICISION_METRICS_MULTILABEL
2926
from autointent.nodes import InferenceNode, NodeOptimizer
3027
from autointent.utils import load_preset, load_search_space
@@ -64,11 +61,19 @@ def __init__(
6461
self.data_config = DataConfig()
6562
self.transformer_config = HFModelConfig()
6663
self.hpo_config = HPOConfig()
64+
self.vector_index_config = get_default_vector_index_config()
6765
elif not isinstance(nodes[0], InferenceNode):
6866
assert_never(nodes)
6967

7068
def set_config(
71-
self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig | HFModelConfig | HPOConfig
69+
self,
70+
config: LoggingConfig
71+
| EmbedderConfig
72+
| CrossEncoderConfig
73+
| DataConfig
74+
| HFModelConfig
75+
| HPOConfig
76+
| VectorIndexConfig,
7277
) -> None:
7378
"""Set the configuration for the pipeline.
7479
@@ -87,6 +92,8 @@ def set_config(
8792
self.transformer_config = config
8893
elif isinstance(config, HPOConfig):
8994
self.hpo_config = config
95+
elif isinstance(config, VectorIndexConfig):
96+
self.vector_index_config = config
9097
else:
9198
assert_never(config)
9299

@@ -203,6 +210,7 @@ def fit(
203210
context.configure_transformer(self.cross_encoder_config)
204211
context.configure_transformer(self.transformer_config)
205212
context.configure_hpo(self.hpo_config)
213+
context.configure_vector_index(self.vector_index_config)
206214

207215
self.validate_modules(dataset, mode=incompatible_search_space)
208216

autointent/_wrappers/vector_index/opensearch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ def add(self, embeddings: NDArray[Any], documents: list[Document]) -> None:
107107
}
108108
)
109109

110-
111110
self._init_index()
112111

113112
# Use bulk API for efficient indexing

autointent/configs/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
TaskTypeEnum,
1212
TokenizerConfig,
1313
)
14-
from ._vector_index import FaissConfig, OpenSearchConfig, VectorIndexConfig
14+
from ._vector_index import FaissConfig, OpenSearchConfig, VectorIndexConfig, get_default_vector_index_config
1515

1616
__all__ = [
1717
"CrossEncoderConfig",
@@ -29,4 +29,5 @@
2929
"TorchTrainingConfig",
3030
"VectorIndexConfig",
3131
"VocabConfig",
32+
"get_default_vector_index_config",
3233
]

autointent/configs/_vector_index.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,7 @@ class OpenSearchConfig(VectorIndexConfig):
1818
hosts: list[OpenSearchHost]
1919
index_name: str | None = None
2020
kwargs: dict[str, Any] = Field(default_factory=dict) # TODO define set of options
21+
22+
23+
def get_default_vector_index_config() -> VectorIndexConfig:
24+
return FaissConfig()

autointent/context/_context.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,15 @@
88

99
from autointent import Dataset
1010
from autointent._callbacks import CallbackHandler, get_callbacks
11-
from autointent.configs import CrossEncoderConfig, DataConfig, EmbedderConfig, HFModelConfig, HPOConfig, LoggingConfig
11+
from autointent.configs import (
12+
CrossEncoderConfig,
13+
DataConfig,
14+
EmbedderConfig,
15+
HFModelConfig,
16+
HPOConfig,
17+
LoggingConfig,
18+
VectorIndexConfig,
19+
)
1220

1321
from .data_handler import DataHandler
1422
from .optimization_info import OptimizationInfo
@@ -74,6 +82,12 @@ def configure_hpo(self, config: HPOConfig) -> None:
7482
else:
7583
assert_never(config)
7684

85+
def configure_vector_index(self, config: VectorIndexConfig) -> None:
86+
if isinstance(config, VectorIndexConfig):
87+
self.vector_index_config = config
88+
else:
89+
assert_never(config)
90+
7791
def set_dataset(self, dataset: Dataset, config: DataConfig) -> None:
7892
"""Set the datasets for training, validation and testing.
7993

autointent/modules/embedding/_retrieval.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pydantic import PositiveInt
66

77
from autointent import Context, VectorIndex
8-
from autointent.configs import EmbedderConfig
8+
from autointent.configs import EmbedderConfig, VectorIndexConfig, get_default_vector_index_config
99
from autointent.context.optimization_info import EmbeddingArtifact
1010
from autointent.custom_types import ListOfLabels
1111
from autointent.metrics import RETRIEVAL_METRICS_MULTICLASS, RETRIEVAL_METRICS_MULTILABEL
@@ -47,11 +47,13 @@ class RetrievalAimedEmbedding(BaseEmbedding):
4747
def __init__(
4848
self,
4949
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
50+
vector_index_config: VectorIndexConfig | None = None,
5051
k: PositiveInt = 10,
5152
) -> None:
5253
self.k = k
5354
embedder_config = EmbedderConfig.from_search_config(embedder_config)
5455
self.embedder_config = embedder_config
56+
self.vector_index_config = vector_index_config or get_default_vector_index_config()
5557

5658
if self.k < 0 or not isinstance(self.k, int):
5759
msg = "`k` argument of `RetrievalAimedEmbedding` must be a positive int"
@@ -74,6 +76,7 @@ def from_context(
7476
return cls(
7577
k=k,
7678
embedder_config=embedder_config,
79+
vector_index_config=context.vector_index_config,
7780
)
7881

7982
def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
@@ -85,9 +88,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
8588
"""
8689
self._validate_task(labels)
8790

88-
self._vector_index = VectorIndex(
89-
self.embedder_config,
90-
)
91+
self._vector_index = VectorIndex(self.embedder_config, config=self.vector_index_config)
9192
self._vector_index.add(utterances, labels)
9293

9394
def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]:

autointent/modules/scoring/_dnnc/dnnc.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pydantic import PositiveInt
1010

1111
from autointent import Context, Ranker, VectorIndex
12-
from autointent.configs import CrossEncoderConfig, EmbedderConfig
12+
from autointent.configs import CrossEncoderConfig, EmbedderConfig, VectorIndexConfig, get_default_vector_index_config
1313
from autointent.custom_types import ListOfLabels
1414
from autointent.modules.base import BaseScorer
1515

@@ -64,10 +64,12 @@ def __init__(
6464
k: PositiveInt = 5,
6565
cross_encoder_config: CrossEncoderConfig | str | dict[str, Any] | None = None,
6666
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
67+
vector_index_config: VectorIndexConfig | None = None,
6768
) -> None:
6869
self.cross_encoder_config = CrossEncoderConfig.from_search_config(cross_encoder_config)
6970
self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
7071
self.k = k
72+
self.vector_index_config = vector_index_config or get_default_vector_index_config()
7173

7274
if self.k < 0 or not isinstance(self.k, int):
7375
msg = "`k` argument of `DNNCScorer` must be a positive int"
@@ -99,6 +101,7 @@ def from_context(
99101
k=k,
100102
embedder_config=embedder_config,
101103
cross_encoder_config=cross_encoder_config,
104+
vector_index_config=context.vector_index_config,
102105
)
103106

104107
def get_implicit_initialization_params(self) -> dict[str, Any]:
@@ -119,7 +122,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
119122
"""
120123
self._validate_task(labels)
121124

122-
self._vector_index = VectorIndex(self.embedder_config)
125+
self._vector_index = VectorIndex(self.embedder_config, config=self.vector_index_config)
123126
self._vector_index.add(utterances, labels)
124127

125128
self._cross_encoder = Ranker(self.cross_encoder_config, output_range="sigmoid")

autointent/modules/scoring/_knn/knn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pydantic import PositiveInt
88

99
from autointent import Context, VectorIndex
10-
from autointent.configs import EmbedderConfig
10+
from autointent.configs import EmbedderConfig, VectorIndexConfig, get_default_vector_index_config
1111
from autointent.custom_types import ListOfLabels, WeightType
1212
from autointent.modules.base import BaseScorer
1313

@@ -59,10 +59,12 @@ def __init__(
5959
k: PositiveInt = 5,
6060
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
6161
weights: WeightType = "distance",
62+
vector_index_config: VectorIndexConfig | None = None,
6263
) -> None:
6364
self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
6465
self.k = k
6566
self.weights = weights
67+
self.vector_index_config = vector_index_config or get_default_vector_index_config()
6668

6769
if self.k < 0 or not isinstance(self.k, int):
6870
msg = "`k` argument of `KNNScorer` must be a positive int"
@@ -92,9 +94,7 @@ def from_context(
9294
embedder_config = context.resolve_embedder()
9395

9496
return cls(
95-
embedder_config=embedder_config,
96-
k=k,
97-
weights=weights,
97+
embedder_config=embedder_config, k=k, weights=weights, vector_index_config=context.vector_index_config
9898
)
9999

100100
def get_implicit_initialization_params(self) -> dict[str, Any]:
@@ -113,7 +113,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
113113
"""
114114
self._validate_task(labels)
115115

116-
self._vector_index = VectorIndex(self.embedder_config)
116+
self._vector_index = VectorIndex(self.embedder_config, config=self.vector_index_config)
117117
self._vector_index.add(utterances, labels)
118118

119119
def predict(self, utterances: list[str]) -> npt.NDArray[Any]:

autointent/modules/scoring/_knn/rerank_scorer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pydantic import PositiveInt
88

99
from autointent import Context, Ranker
10-
from autointent.configs import CrossEncoderConfig, EmbedderConfig
10+
from autointent.configs import CrossEncoderConfig, EmbedderConfig, VectorIndexConfig, get_default_vector_index_config
1111
from autointent.custom_types import ListOfLabels, WeightType
1212

1313
from .knn import KNNScorer
@@ -42,6 +42,7 @@ def __init__(
4242
m: PositiveInt | None = None,
4343
cross_encoder_config: CrossEncoderConfig | str | dict[str, Any] | None = None,
4444
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
45+
vector_index_config: VectorIndexConfig | None = None,
4546
) -> None:
4647
super().__init__(
4748
embedder_config=embedder_config,
@@ -50,6 +51,7 @@ def __init__(
5051
)
5152

5253
self.cross_encoder_config = CrossEncoderConfig.from_search_config(cross_encoder_config)
54+
self.vector_index_config = vector_index_config or get_default_vector_index_config()
5355

5456
self.m = k if m is None else m
5557
self.use_cross_encoder_scores = use_cross_encoder_scores
@@ -94,6 +96,7 @@ def from_context(
9496
use_cross_encoder_scores=use_cross_encoder_scores,
9597
embedder_config=embedder_config,
9698
cross_encoder_config=cross_encoder_config,
99+
vector_index_config=context.vector_index_config,
97100
)
98101

99102
def get_implicit_initialization_params(self) -> dict[str, Any]:

autointent/modules/scoring/_mlknn/mlknn.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing_extensions import assert_never
99

1010
from autointent import Context, VectorIndex
11-
from autointent.configs import EmbedderConfig
11+
from autointent.configs import EmbedderConfig, VectorIndexConfig, get_default_vector_index_config
1212
from autointent.custom_types import ListOfLabels
1313
from autointent.modules.base import BaseScorer
1414

@@ -67,11 +67,13 @@ def __init__(
6767
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
6868
s: float = 1.0,
6969
ignore_first_neighbours: int = 0,
70+
vector_index_config: VectorIndexConfig | None = None,
7071
) -> None:
7172
self.k = k
7273
self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
7374
self.s = s
7475
self.ignore_first_neighbours = ignore_first_neighbours
76+
self.vector_index_config = vector_index_config or get_default_vector_index_config()
7577

7678
if self.k < 0 or not isinstance(self.k, int):
7779
msg = "`k` argument of `MLKnnScorer` must be a positive int"
@@ -109,6 +111,7 @@ def from_context(
109111
embedder_config=embedder_config,
110112
s=s,
111113
ignore_first_neighbours=ignore_first_neighbours,
114+
vector_index_config=context.vector_index_config,
112115
)
113116

114117
def get_implicit_initialization_params(self) -> dict[str, Any]:
@@ -127,15 +130,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
127130
"""
128131
self._validate_task(labels)
129132

130-
self._vector_index = VectorIndex(
131-
EmbedderConfig(
132-
model_name=self.embedder_config.model_name,
133-
device=self.embedder_config.device,
134-
batch_size=self.embedder_config.batch_size,
135-
tokenizer_config=self.embedder_config.tokenizer_config,
136-
use_cache=self.embedder_config.use_cache,
137-
),
138-
)
133+
self._vector_index = VectorIndex(embedder_config=self.embedder_config, config=self.vector_index_config)
139134
self._vector_index.add(utterances, labels)
140135

141136
self._features = self._vector_index.get_all_embeddings()

0 commit comments

Comments
 (0)