
Commit 3b77616

voorhs and Samoed authored
Feat/open search (#259)
* outline new architecture for vector index wrapper
* implement faiss backend
* update vector index wrapper
* upd embedder a little bit
* implement opensearch backend
* faiss bug fix
* opensearch bug fix
* change usage of vector index
* typeddict bug fix
* fix assert never issue
* update vector index usage
* run formatter
* fix rerank scorer
* fix dnnc
* fix dumper test
* try to fix some typing problems
* bug fix
* Update autointent/_wrappers/vector_index/vector_index.py
  Co-authored-by: Roman Solomatin <[email protected]>
* fix typing
* add vector index tests
* add threadpool limits for faiss on macos
* fix opensearch tests a little bit
* `BaseBackend` -> `BaseIndexBackend`
* upd reranker tests
* fix typing
* fix typing

---------

Co-authored-by: Roman Solomatin <[email protected]>
1 parent ee655bb commit 3b77616

File tree: 24 files changed, +1087 -353 lines

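The commit message outlines a new vector index wrapper with interchangeable FAISS and OpenSearch backends behind a shared `BaseIndexBackend` interface (hence the `BaseBackend` -> `BaseIndexBackend` rename above). Those wrapper files are not part of the excerpt below, so the following is only an illustrative sketch of that backend-abstraction pattern; apart from the `BaseIndexBackend` name, every class, method, and signature here is an assumption, not the actual autointent API.

# Illustrative sketch only: the real autointent backend interface is not shown in this excerpt.
from abc import ABC, abstractmethod

import numpy as np
import numpy.typing as npt


class BaseIndexBackend(ABC):
    """Hypothetical common interface that concrete vector index backends would implement."""

    @abstractmethod
    def add(self, embeddings: npt.NDArray[np.float32], labels: list[int]) -> None: ...

    @abstractmethod
    def query(self, query: npt.NDArray[np.float32], k: int) -> tuple[list[int], list[float]]: ...


class FaissBackend(BaseIndexBackend):
    """In-process backend built on faiss (hypothetical)."""

    def __init__(self, dim: int) -> None:
        import faiss  # deferred import keeps the dependency optional

        self._index = faiss.IndexFlatIP(dim)  # inner-product index over dim-dimensional embeddings
        self._labels: list[int] = []

    def add(self, embeddings: npt.NDArray[np.float32], labels: list[int]) -> None:
        self._index.add(embeddings.astype(np.float32))
        self._labels.extend(labels)

    def query(self, query: npt.NDArray[np.float32], k: int) -> tuple[list[int], list[float]]:
        scores, ids = self._index.search(query.reshape(1, -1).astype(np.float32), k)
        return [self._labels[i] for i in ids[0]], scores[0].tolist()

An OpenSearch backend would implement the same two methods against a remote cluster, which is what lets the rest of the codebase switch backends through configuration alone.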
autointent/_pipeline/_pipeline.py

Lines changed: 15 additions & 7 deletions
@@ -18,13 +18,10 @@
     HPOConfig,
     InferenceNodeConfig,
     LoggingConfig,
+    VectorIndexConfig,
+    get_default_vector_index_config,
 )
-from autointent.custom_types import (
-    ListOfGenericLabels,
-    NodeType,
-    SearchSpacePreset,
-    SearchSpaceValidationMode,
-)
+from autointent.custom_types import ListOfGenericLabels, NodeType, SearchSpacePreset, SearchSpaceValidationMode
 from autointent.metrics import DECISION_METRICS, DICISION_METRICS_MULTILABEL
 from autointent.nodes import InferenceNode, NodeOptimizer
 from autointent.utils import load_preset, load_search_space
@@ -64,11 +61,19 @@ def __init__(
             self.data_config = DataConfig()
             self.transformer_config = HFModelConfig()
             self.hpo_config = HPOConfig()
+            self.vector_index_config = get_default_vector_index_config()
         elif not isinstance(nodes[0], InferenceNode):
             assert_never(nodes)

     def set_config(
-        self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig | HFModelConfig | HPOConfig
+        self,
+        config: LoggingConfig
+        | EmbedderConfig
+        | CrossEncoderConfig
+        | DataConfig
+        | HFModelConfig
+        | HPOConfig
+        | VectorIndexConfig,
     ) -> None:
         """Set the configuration for the pipeline.

@@ -87,6 +92,8 @@ def set_config(
             self.transformer_config = config
         elif isinstance(config, HPOConfig):
             self.hpo_config = config
+        elif isinstance(config, VectorIndexConfig):
+            self.vector_index_config = config
         else:
             assert_never(config)

@@ -203,6 +210,7 @@ def fit(
         context.configure_transformer(self.cross_encoder_config)
         context.configure_transformer(self.transformer_config)
         context.configure_hpo(self.hpo_config)
+        context.configure_vector_index(self.vector_index_config)

         self.validate_modules(dataset, mode=incompatible_search_space)

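The `set_config` union now accepts a `VectorIndexConfig`, and `fit` forwards it to the optimization context. A minimal usage sketch, assuming `pipeline` and `dataset` objects already exist; `VectorIndexConfig`'s import path and its `backend` keyword are assumptions, since neither is visible in this diff:

from autointent.configs import VectorIndexConfig  # import path assumed

pipeline.set_config(VectorIndexConfig(backend="faiss"))  # handled by the new isinstance branch in set_config
pipeline.fit(dataset)  # fit() now calls context.configure_vector_index(self.vector_index_config)

If no config is set explicitly, the constructor falls back to get_default_vector_index_config().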
autointent/_wrappers/embedder.py

Lines changed: 11 additions & 35 deletions
@@ -9,7 +9,6 @@
 import shutil
 from functools import lru_cache
 from pathlib import Path
-from typing import TypedDict

 import huggingface_hub
 import numpy as np
@@ -59,23 +58,6 @@ def _get_latest_commit_hash(model_name: str) -> str:
     return commit_hash


-class EmbedderDumpMetadata(TypedDict):
-    """Metadata for saving and loading an Embedder instance."""
-
-    model_name: str
-    """Name of the hugging face model or a local path to sentence transformers dump."""
-    device: str | None
-    """Torch notation for CPU or CUDA."""
-    batch_size: int
-    """Batch size used for embedding calculations."""
-    max_length: int | None
-    """Maximum sequence length for the embedding model."""
-    use_cache: bool
-    """Whether to use embeddings caching."""
-    similarity_fn_name: str | None
-    """Name of the similarity function to use."""
-
-
 class Embedder:
     """A wrapper for managing embedding models using :py:class:`sentence_transformers.SentenceTransformer`.

@@ -85,7 +67,6 @@ class Embedder:

     _metadata_dict_name: str = "metadata.json"
     _dump_dir: Path | None = None
-    embedding_model: SentenceTransformer

     def __init__(self, embedder_config: EmbedderConfig) -> None:
         """Initialize the Embedder.
@@ -106,22 +87,25 @@ def _get_hash(self) -> int:
             commit_hash = _get_latest_commit_hash(self.config.model_name)
             hasher.update(commit_hash)
         else:
-            self._load_model()
+            self.embedding_model = self._load_model()
             for parameter in self.embedding_model.parameters():
                 hasher.update(parameter.detach().cpu().numpy())
         hasher.update(self.config.tokenizer_config.max_length)
         return hasher.intdigest()

-    def _load_model(self) -> None:
+    def _load_model(self) -> SentenceTransformer:
         """Load sentence transformers model to device."""
         if not hasattr(self, "embedding_model"):
-            self.embedding_model = SentenceTransformer(
+            res = SentenceTransformer(
                 self.config.model_name,
                 device=self.config.device,
                 prompts=self.config.get_prompt_config(),
                 similarity_fn_name=self.config.similarity_fn_name,
                 trust_remote_code=self.config.trust_remote_code,
             )
+        else:
+            res = self.embedding_model
+        return res

     def clear_ram(self) -> None:
         """Move the embedding model to CPU and delete it from memory."""
@@ -144,17 +128,9 @@ def dump(self, path: Path) -> None:
             path: Path to the directory where the model will be saved.
         """
         self._dump_dir = path
-        metadata = EmbedderDumpMetadata(
-            model_name=str(self.config.model_name),
-            device=self.config.device,
-            batch_size=self.config.batch_size,
-            max_length=self.config.tokenizer_config.max_length,
-            use_cache=self.config.use_cache,
-            similarity_fn_name=self.config.similarity_fn_name,
-        )
         path.mkdir(parents=True, exist_ok=True)
         with (path / self._metadata_dict_name).open("w") as file:
-            json.dump(metadata, file, indent=4)
+            json.dump(self.config.model_dump(mode="json"), file, indent=4)

     @classmethod
     def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -> "Embedder":
@@ -165,12 +141,12 @@ def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -
             override_config: one can override presaved settings
         """
         with (Path(path) / cls._metadata_dict_name).open(encoding="utf-8") as file:
-            metadata: EmbedderDumpMetadata = json.load(file)
+            config = EmbedderConfig.model_validate_json(file.read())

         if override_config is not None:
-            kwargs = {**metadata, **override_config.model_dump(exclude_unset=True)}
+            kwargs = {**config.model_dump(), **override_config.model_dump(exclude_unset=True)}
         else:
-            kwargs = metadata  # type: ignore[assignment]
+            kwargs = config.model_dump()

         max_length = kwargs.pop("max_length", None)
         if max_length is not None:
@@ -203,7 +179,7 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
             logger.debug("loading embeddings from %s", str(embeddings_path))
             return np.load(embeddings_path)  # type: ignore[no-any-return]

-        self._load_model()
+        self.embedding_model = self._load_model()

         logger.debug(
             "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s, prompt=%s",

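With this change the embedder dump no longer goes through a hand-maintained `EmbedderDumpMetadata` TypedDict: `dump` serializes the whole `EmbedderConfig` via pydantic's `model_dump(mode="json")`, and `load` restores it with `EmbedderConfig.model_validate_json`. A rough round-trip sketch; the top-level import paths and the ability to construct `EmbedderConfig` from `model_name` alone are assumptions:

from pathlib import Path

from autointent import Embedder  # import path assumed
from autointent.configs import EmbedderConfig  # import path assumed

embedder = Embedder(EmbedderConfig(model_name="sentence-transformers/all-MiniLM-L6-v2"))
embedder.dump(Path("dumps/embedder"))       # writes metadata.json from config.model_dump(mode="json")
restored = Embedder.load("dumps/embedder")  # parses metadata.json back into an EmbedderConfig

Because the file now holds a full config rather than a fixed set of keys, new config fields are persisted automatically instead of requiring a TypedDict to be kept in sync.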
autointent/_wrappers/ranker.py

Lines changed: 5 additions & 5 deletions
@@ -21,7 +21,7 @@
 from torch import nn

 from autointent.configs import CrossEncoderConfig
-from autointent.custom_types import ListOfLabels
+from autointent.custom_types import ListOfLabels, RerankedItem

 logger = logging.getLogger(__name__)

@@ -194,7 +194,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
         pairs, labels_ = construct_samples(utterances, labels, balancing_factor=1)
         self._fit(pairs, labels_)  # type: ignore[arg-type]

-    def predict(self, pairs: list[tuple[str, str]]) -> npt.NDArray[Any]:
+    def predict(self, pairs: list[tuple[str, str]]) -> npt.NDArray[np.float32]:
         """Predict probabilities of two utterances having the same intent label.

         Args:
@@ -224,7 +224,7 @@ def rank(
         query: str,
         query_docs: list[str],
         top_k: int | None = None,
-    ) -> list[dict[str, Any]]:
+    ) -> list[RerankedItem]:
         """Rank documents according to meaning closeness to the query.

         Args:
@@ -241,8 +241,8 @@ def rank(
         if top_k is None:
             top_k = len(query_docs)

-        results = [{"corpus_id": i, "score": scores[i]} for i in range(len(query_docs))]
-        results.sort(key=lambda x: x["score"], reverse=True)
+        results = [RerankedItem(corpus_id=i, score=scores[i]) for i in range(len(query_docs))]
+        results.sort(key=lambda x: x.score, reverse=True)
         return results[:top_k]

     def save(self, path: str) -> None:

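`Ranker.rank` now returns typed `RerankedItem` objects instead of raw dicts, so callers read `item.score` rather than `item["score"]`. `RerankedItem`'s actual definition lives in `autointent.custom_types` and is not shown in this diff; a minimal equivalent consistent with how the diff constructs and sorts it would be:

from dataclasses import dataclass


@dataclass
class RerankedItem:
    """One ranked document: its index in the input list and its cross-encoder score (hypothetical sketch)."""

    corpus_id: int
    score: float

The diff also tightens `predict` to return `npt.NDArray[np.float32]`, matching the dtype of the cross-encoder scores.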
0 commit comments
