
Commit cb9b2ea: "pull dev"
Merge commit with 2 parents: 714f910 + 3b77616

File tree: 30 files changed, +1912 additions, -356 deletions


autointent/_pipeline/_pipeline.py

Lines changed: 15 additions & 7 deletions
@@ -18,13 +18,10 @@
     HPOConfig,
     InferenceNodeConfig,
     LoggingConfig,
+    VectorIndexConfig,
+    get_default_vector_index_config,
 )
-from autointent.custom_types import (
-    ListOfGenericLabels,
-    NodeType,
-    SearchSpacePreset,
-    SearchSpaceValidationMode,
-)
+from autointent.custom_types import ListOfGenericLabels, NodeType, SearchSpacePreset, SearchSpaceValidationMode
 from autointent.metrics import DECISION_METRICS, DICISION_METRICS_MULTILABEL
 from autointent.nodes import InferenceNode, NodeOptimizer
 from autointent.utils import load_preset, load_search_space
@@ -64,11 +61,19 @@ def __init__(
             self.data_config = DataConfig()
             self.transformer_config = HFModelConfig()
             self.hpo_config = HPOConfig()
+            self.vector_index_config = get_default_vector_index_config()
         elif not isinstance(nodes[0], InferenceNode):
             assert_never(nodes)

     def set_config(
-        self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig | HFModelConfig | HPOConfig
+        self,
+        config: LoggingConfig
+        | EmbedderConfig
+        | CrossEncoderConfig
+        | DataConfig
+        | HFModelConfig
+        | HPOConfig
+        | VectorIndexConfig,
     ) -> None:
         """Set the configuration for the pipeline.

@@ -87,6 +92,8 @@ def set_config(
             self.transformer_config = config
         elif isinstance(config, HPOConfig):
             self.hpo_config = config
+        elif isinstance(config, VectorIndexConfig):
+            self.vector_index_config = config
         else:
             assert_never(config)

@@ -203,6 +210,7 @@ def fit(
         context.configure_transformer(self.cross_encoder_config)
         context.configure_transformer(self.transformer_config)
         context.configure_hpo(self.hpo_config)
+        context.configure_vector_index(self.vector_index_config)

         self.validate_modules(dataset, mode=incompatible_search_space)
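The new elif branch keeps set_config an exhaustive isinstance dispatch that falls through to assert_never, so a type checker flags any config type added to the union but left unhandled. A minimal self-contained sketch of the same pattern with stand-in config classes (not the real autointent types):

from dataclasses import dataclass
from typing import assert_never  # Python 3.11+; use typing_extensions on older versions


@dataclass
class LoggingCfg:
    level: str = "INFO"


@dataclass
class VectorIndexCfg:
    backend: str = "flat"


class PipelineSketch:
    """Stand-in for Pipeline, showing only the exhaustive dispatch logic."""

    def set_config(self, config: LoggingCfg | VectorIndexCfg) -> None:
        if isinstance(config, LoggingCfg):
            self.logging_config = config
        elif isinstance(config, VectorIndexCfg):
            # Analogue of the branch this commit adds for VectorIndexConfig.
            self.vector_index_config = config
        else:
            assert_never(config)  # unreachable if every union member is handled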

autointent/_wrappers/embedder.py

Lines changed: 11 additions & 35 deletions
@@ -10,7 +10,6 @@
 import tempfile
 from functools import lru_cache
 from pathlib import Path
-from typing import TypedDict

 import huggingface_hub
 import numpy as np
@@ -65,23 +64,6 @@ def _get_latest_commit_hash(model_name: str) -> str:
     return commit_hash


-class EmbedderDumpMetadata(TypedDict):
-    """Metadata for saving and loading an Embedder instance."""
-
-    model_name: str
-    """Name of the hugging face model or a local path to sentence transformers dump."""
-    device: str | None
-    """Torch notation for CPU or CUDA."""
-    batch_size: int
-    """Batch size used for embedding calculations."""
-    max_length: int | None
-    """Maximum sequence length for the embedding model."""
-    use_cache: bool
-    """Whether to use embeddings caching."""
-    similarity_fn_name: str | None
-    """Name of the similarity function to use."""
-
-
 class Embedder:
     """A wrapper for managing embedding models using :py:class:`sentence_transformers.SentenceTransformer`.

@@ -91,7 +73,6 @@ class Embedder:

     _metadata_dict_name: str = "metadata.json"
     _dump_dir: Path | None = None
-    embedding_model: SentenceTransformer

     def __init__(self, embedder_config: EmbedderConfig) -> None:
         """Initialize the Embedder.
@@ -112,22 +93,25 @@ def _get_hash(self) -> int:
             commit_hash = _get_latest_commit_hash(self.config.model_name)
             hasher.update(commit_hash)
         else:
-            self._load_model()
+            self.embedding_model = self._load_model()
             for parameter in self.embedding_model.parameters():
                 hasher.update(parameter.detach().cpu().numpy())
         hasher.update(self.config.tokenizer_config.max_length)
         return hasher.intdigest()

-    def _load_model(self) -> None:
+    def _load_model(self) -> SentenceTransformer:
         """Load sentence transformers model to device."""
         if not hasattr(self, "embedding_model"):
-            self.embedding_model = SentenceTransformer(
+            res = SentenceTransformer(
                 self.config.model_name,
                 device=self.config.device,
                 prompts=self.config.get_prompt_config(),
                 similarity_fn_name=self.config.similarity_fn_name,
                 trust_remote_code=self.config.trust_remote_code,
             )
+        else:
+            res = self.embedding_model
+        return res

     def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTuningConfig) -> None:
         """Train the embedding model."""
@@ -199,17 +183,9 @@ def dump(self, path: Path) -> None:
             path: Path to the directory where the model will be saved.
         """
         self._dump_dir = path
-        metadata = EmbedderDumpMetadata(
-            model_name=str(self.config.model_name),
-            device=self.config.device,
-            batch_size=self.config.batch_size,
-            max_length=self.config.tokenizer_config.max_length,
-            use_cache=self.config.use_cache,
-            similarity_fn_name=self.config.similarity_fn_name,
-        )
         path.mkdir(parents=True, exist_ok=True)
         with (path / self._metadata_dict_name).open("w") as file:
-            json.dump(metadata, file, indent=4)
+            json.dump(self.config.model_dump(mode="json"), file, indent=4)

     @classmethod
     def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -> "Embedder":
@@ -220,12 +196,12 @@ def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -
             override_config: one can override presaved settings
         """
         with (Path(path) / cls._metadata_dict_name).open(encoding="utf-8") as file:
-            metadata: EmbedderDumpMetadata = json.load(file)
+            config = EmbedderConfig.model_validate_json(file.read())

         if override_config is not None:
-            kwargs = {**metadata, **override_config.model_dump(exclude_unset=True)}
+            kwargs = {**config.model_dump(), **override_config.model_dump(exclude_unset=True)}
         else:
-            kwargs = metadata  # type: ignore[assignment]
+            kwargs = config.model_dump()

         max_length = kwargs.pop("max_length", None)
         if max_length is not None:
@@ -258,7 +234,7 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
             logger.debug("loading embeddings from %s", str(embeddings_path))
             return np.load(embeddings_path)  # type: ignore[no-any-return]

-        self._load_model()
+        self.embedding_model = self._load_model()

         logger.debug(
             "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s, prompt=%s",

autointent/_wrappers/ranker.py

Lines changed: 5 additions & 5 deletions
@@ -21,7 +21,7 @@
 from torch import nn

 from autointent.configs import CrossEncoderConfig
-from autointent.custom_types import ListOfLabels
+from autointent.custom_types import ListOfLabels, RerankedItem

 logger = logging.getLogger(__name__)

@@ -194,7 +194,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
         pairs, labels_ = construct_samples(utterances, labels, balancing_factor=1)
         self._fit(pairs, labels_)  # type: ignore[arg-type]

-    def predict(self, pairs: list[tuple[str, str]]) -> npt.NDArray[Any]:
+    def predict(self, pairs: list[tuple[str, str]]) -> npt.NDArray[np.float32]:
         """Predict probabilities of two utterances having the same intent label.

         Args:
@@ -224,7 +224,7 @@ def rank(
         query: str,
         query_docs: list[str],
         top_k: int | None = None,
-    ) -> list[dict[str, Any]]:
+    ) -> list[RerankedItem]:
         """Rank documents according to meaning closeness to the query.

         Args:
@@ -241,8 +241,8 @@ def rank(
         if top_k is None:
             top_k = len(query_docs)

-        results = [{"corpus_id": i, "score": scores[i]} for i in range(len(query_docs))]
-        results.sort(key=lambda x: x["score"], reverse=True)
+        results = [RerankedItem(corpus_id=i, score=scores[i]) for i in range(len(query_docs))]
+        results.sort(key=lambda x: x.score, reverse=True)
         return results[:top_k]

     def save(self, path: str) -> None:
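RerankedItem is imported from autointent.custom_types; its definition is not part of this diff, but the usage implies keyword construction with corpus_id and score and attribute access in place of the old dict lookups. A sketch with a dataclass stand-in (the real type may be a pydantic model or carry extra fields):

from dataclasses import dataclass


@dataclass
class RerankedItemSketch:
    """Stand-in for autointent.custom_types.RerankedItem, inferred from its usage in rank."""

    corpus_id: int
    score: float


def rank_sketch(query_docs: list[str], scores: list[float], top_k: int = 3) -> list[RerankedItemSketch]:
    # Mirrors the updated Ranker.rank tail: build typed items, sort by score,
    # then slice; callers read .corpus_id and .score instead of dict keys.
    results = [RerankedItemSketch(corpus_id=i, score=scores[i]) for i in range(len(query_docs))]
    results.sort(key=lambda item: item.score, reverse=True)
    return results[:top_k]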
