Skip to content

Commit 071b38f

Browse files
authored
Vector index and embeddings caching (#89)
1 parent b71a301 commit 071b38f

File tree

11 files changed

+130
-222
lines changed

11 files changed

+130
-222
lines changed

autointent/_embedder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(
6060
device: str = "cpu",
6161
batch_size: int = 32,
6262
max_length: int | None = None,
63-
use_cache: bool = False,
63+
use_cache: bool = True,
6464
) -> None:
6565
"""
6666
Initialize the Embedder.
@@ -69,7 +69,7 @@ def __init__(
6969
:param device: Device to run the model on (e.g., "cpu", "cuda").
7070
:param batch_size: Batch size for embedding calculations.
7171
:param max_length: Maximum sequence length for the embedding model.
72-
:param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
72+
:param use_cache: Flag indicating whether to cache intermediate embeddings.
7373
"""
7474
self.model_name = model_name
7575
self.device = device

autointent/configs/_optimization_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class EmbedderConfig:
109109
"""Batch size for the embedder"""
110110
max_length: int | None = None
111111
"""Max length for the embedder. If None, the max length will be taken from model config"""
112-
use_cache: bool = False
112+
use_cache: bool = True
113113
"""Flag indicating whether to cache embeddings for reuse, improving performance in repeated operations."""
114114
device: str = "cpu"
115115
"""Device to use for the vector index. Can be 'cpu', 'cuda', 'cuda:0', 'mps', etc."""

autointent/context/vector_index_client/_vector_index.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(
3131
embedder_device: str,
3232
embedder_batch_size: int = 32,
3333
embedder_max_length: int | None = None,
34-
embedder_use_cache: bool = False,
34+
embedder_use_cache: bool = True,
3535
) -> None:
3636
"""
3737
Initialize the vector index.
@@ -121,7 +121,7 @@ def _search_by_embedding(self, embedding: npt.NDArray[Any], k: int) -> list[list
121121
msg = "`embedding` should be a 2D array of shape (n_queries, dim_size)"
122122
raise ValueError(msg)
123123

124-
cos_sim, indices = self.index.search(embedding, k)
124+
cos_sim, indices = self.index.search(embedding, k) # TODO add caching similar to Embedder.embed() caching
125125
distances = 1 - cos_sim
126126

127127
results = []

autointent/context/vector_index_client/_vector_index_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(
3232
db_dir: str | Path | None,
3333
embedder_batch_size: int = 32,
3434
embedder_max_length: int | None = None,
35-
embedder_use_cache: bool = False,
35+
embedder_use_cache: bool = True,
3636
) -> None:
3737
"""
3838
Initialize the VectorIndexClient.

autointent/modules/embedding/_retrieval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __init__(
7272
embedder_device: str = "cpu",
7373
batch_size: int = 32,
7474
max_length: int | None = None,
75-
embedder_use_cache: bool = False,
75+
embedder_use_cache: bool = True,
7676
) -> None:
7777
"""
7878
Initialize the RetrievalEmbedding.

autointent/modules/scoring/_description/description.py

Lines changed: 25 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,17 @@
1010
from sklearn.metrics.pairwise import cosine_similarity
1111

1212
from autointent import Context, Embedder
13-
from autointent.context.vector_index_client import VectorIndex, VectorIndexClient
1413
from autointent.custom_types import LabelType
1514
from autointent.modules.abc import ScoringModule
1615

1716

1817
class DescriptionScorerDumpMetadata(TypedDict):
1918
"""Metadata for dumping the state of a DescriptionScorer."""
2019

21-
db_dir: str
2220
n_classes: int
2321
multilabel: bool
24-
batch_size: int
25-
max_length: int | None
22+
embedder_batch_size: int
23+
embedder_max_length: int | None
2624

2725

2826
class DescriptionScorer(ScoringModule):
@@ -34,46 +32,40 @@ class DescriptionScorer(ScoringModule):
3432
3533
:ivar weights_file_name: Filename for saving the description vectors (`description_vectors.npy`).
3634
:ivar embedder: The embedder used to generate embeddings for utterances and descriptions.
37-
:ivar precomputed_embeddings: Flag indicating whether precomputed embeddings are used.
3835
:ivar embedding_model_subdir: Directory for storing the embedder's model files.
39-
:ivar _vector_index: Internal vector index used when embeddings are precomputed.
40-
:ivar db_dir: Directory path where the vector database is stored.
4136
:ivar name: Name of the scorer, defaults to "description".
4237
4338
"""
4439

4540
weights_file_name: str = "description_vectors.npy"
4641
embedder: Embedder
47-
precomputed_embeddings: bool = False
4842
embedding_model_subdir: str = "embedding_model"
49-
_vector_index: VectorIndex
50-
db_dir: str
5143
name = "description"
5244

5345
def __init__(
5446
self,
5547
embedder_name: str,
5648
temperature: float = 1.0,
5749
embedder_device: str = "cpu",
58-
batch_size: int = 32,
59-
max_length: int | None = None,
60-
embedder_use_cache: bool = False,
50+
embedder_batch_size: int = 32,
51+
embedder_max_length: int | None = None,
52+
embedder_use_cache: bool = True,
6153
) -> None:
6254
"""
6355
Initialize the DescriptionScorer.
6456
6557
:param embedder_name: Name of the embedder model.
6658
:param temperature: Temperature parameter for scaling logits, defaults to 1.0.
6759
:param embedder_device: Device to run the embedder on, e.g., "cpu" or "cuda".
68-
:param batch_size: Batch size for embedding generation, defaults to 32.
69-
:param max_length: Maximum sequence length for embedding, defaults to None.
60+
:param embedder_batch_size: Batch size for embedding generation, defaults to 32.
61+
:param embedder_max_length: Maximum sequence length for embedding, defaults to None.
7062
:param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
7163
"""
7264
self.temperature = temperature
7365
self.embedder_device = embedder_device
7466
self.embedder_name = embedder_name
75-
self.batch_size = batch_size
76-
self.max_length = max_length
67+
self.embedder_batch_size = embedder_batch_size
68+
self.embedder_max_length = embedder_max_length
7769
self.embedder_use_cache = embedder_use_cache
7870

7971
@classmethod
@@ -93,19 +85,15 @@ def from_context(
9385
"""
9486
if embedder_name is None:
9587
embedder_name = context.optimization_info.get_best_embedder()
96-
precomputed_embeddings = True
97-
else:
98-
precomputed_embeddings = context.vector_index_client.exists(embedder_name)
9988

100-
instance = cls(
89+
return cls(
10190
temperature=temperature,
10291
embedder_device=context.get_device(),
10392
embedder_name=embedder_name,
10493
embedder_use_cache=context.get_use_cache(),
94+
embedder_batch_size=context.get_batch_size(),
95+
embedder_max_length=context.get_max_length(),
10596
)
106-
instance.precomputed_embeddings = precomputed_embeddings
107-
instance.db_dir = str(context.get_db_dir())
108-
return instance
10997

11098
def get_embedder_name(self) -> str:
11199
"""
@@ -136,39 +124,22 @@ def fit(
136124
self.n_classes = len(set(labels))
137125
self.multilabel = False
138126

139-
if self.precomputed_embeddings:
140-
# this happens only when LinearScorer is within Pipeline optimization after RetrievalNode optimization
141-
vector_index_client = VectorIndexClient(
142-
self.embedder_device,
143-
self.db_dir,
144-
self.batch_size,
145-
self.max_length,
146-
self.embedder_use_cache,
147-
)
148-
vector_index = vector_index_client.get_index(self.embedder_name)
149-
features = vector_index.get_all_embeddings()
150-
if len(features) != len(utterances):
151-
msg = "Vector index mismatches provided utterances"
152-
raise ValueError(msg)
153-
embedder = vector_index.embedder
154-
else:
155-
embedder = Embedder(
156-
device=self.embedder_device,
157-
model_name=self.embedder_name,
158-
batch_size=self.batch_size,
159-
max_length=self.max_length,
160-
use_cache=self.embedder_use_cache,
161-
)
162-
features = embedder.embed(utterances)
163-
164127
if any(description is None for description in descriptions):
165128
error_text = (
166129
"Some intent descriptions (label_description) are missing (None). "
167130
"Please ensure all intents have descriptions."
168131
)
169132
raise ValueError(error_text)
170133

171-
self.description_vectors = embedder.embed([desc for desc in descriptions if desc])
134+
embedder = Embedder(
135+
device=self.embedder_device,
136+
model_name=self.embedder_name,
137+
batch_size=self.embedder_batch_size,
138+
max_length=self.embedder_max_length,
139+
use_cache=self.embedder_use_cache,
140+
)
141+
142+
self.description_vectors = embedder.embed(descriptions)
172143
self.embedder = embedder
173144

174145
def predict(self, utterances: list[str]) -> NDArray[np.float64]:
@@ -198,11 +169,10 @@ def dump(self, path: str) -> None:
198169
:param path: Path to the directory where assets will be dumped.
199170
"""
200171
self.metadata = DescriptionScorerDumpMetadata(
201-
db_dir=str(self.db_dir),
202172
n_classes=self.n_classes,
203173
multilabel=self.multilabel,
204-
batch_size=self.batch_size,
205-
max_length=self.max_length,
174+
embedder_batch_size=self.embedder_batch_size,
175+
embedder_max_length=self.embedder_max_length,
206176
)
207177

208178
dump_dir = Path(path)
@@ -232,7 +202,7 @@ def load(self, path: str) -> None:
232202
self.embedder = Embedder(
233203
device=self.embedder_device,
234204
model_name=embedder_dir,
235-
batch_size=self.metadata["batch_size"],
236-
max_length=self.metadata["max_length"],
205+
batch_size=self.metadata["embedder_batch_size"],
206+
max_length=self.metadata["embedder_max_length"],
237207
use_cache=self.embedder_use_cache,
238208
)

autointent/modules/scoring/_dnnc/dnnc.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ class DNNCScorer(ScoringModule):
5151
5252
:ivar crossencoder_subdir: Subdirectory for storing the cross-encoder model (`crossencoder`).
5353
:ivar model: The model used for scoring, which could be a `CrossEncoder` or a `CrossEncoderWithLogreg`.
54-
:ivar prebuilt_index: Flag indicating whether a prebuilt vector index is used.
5554
:ivar _db_dir: Path to the database directory where the vector index is stored.
5655
:ivar name: Name of the scorer, defaults to "dnnc".
5756
@@ -95,7 +94,6 @@ class DNNCScorer(ScoringModule):
9594

9695
crossencoder_subdir: str = "crossencoder"
9796
model: CrossEncoder | CrossEncoderWithLogreg
98-
prebuilt_index: bool = False
9997

10098
def __init__(
10199
self,
@@ -107,7 +105,7 @@ def __init__(
107105
train_head: bool = False,
108106
batch_size: int = 32,
109107
max_length: int | None = None,
110-
embedder_use_cache: bool = False,
108+
embedder_use_cache: bool = True,
111109
) -> None:
112110
"""
113111
Initialize the DNNCScorer.
@@ -164,11 +162,8 @@ def from_context(
164162
"""
165163
if embedder_name is None:
166164
embedder_name = context.optimization_info.get_best_embedder()
167-
prebuilt_index = True
168-
else:
169-
prebuilt_index = context.vector_index_client.exists(embedder_name)
170165

171-
instance = cls(
166+
return cls(
172167
cross_encoder_name=cross_encoder_name,
173168
embedder_name=embedder_name,
174169
k=k,
@@ -179,8 +174,6 @@ def from_context(
179174
max_length=context.get_max_length(),
180175
embedder_use_cache=context.get_use_cache(),
181176
)
182-
instance.prebuilt_index = prebuilt_index
183-
return instance
184177

185178
def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
186179
"""
@@ -195,15 +188,7 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
195188
self.model = CrossEncoder(self.cross_encoder_name, trust_remote_code=True, device=self.device)
196189

197190
vector_index_client = VectorIndexClient(self.device, self.db_dir, embedder_use_cache=self.embedder_use_cache)
198-
199-
if self.prebuilt_index:
200-
# this happens only when LinearScorer is within Pipeline optimization after RetrievalNode optimization
201-
self.vector_index = vector_index_client.get_index(self.embedder_name)
202-
if len(utterances) != len(self.vector_index.texts):
203-
msg = "Vector index mismatches provided utterances"
204-
raise ValueError(msg)
205-
else:
206-
self.vector_index = vector_index_client.create_index(self.embedder_name, utterances, labels)
191+
self.vector_index = vector_index_client.create_index(self.embedder_name, utterances, labels)
207192

208193
if self.train_head:
209194
model = CrossEncoderWithLogreg(self.model)

0 commit comments

Comments
 (0)