
Commit 272e1d4: Add embeddings caching (#70)
1 parent f0b7885

14 files changed (+209, -10 lines)

autointent/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,6 +1,7 @@
 from ._embedder import Embedder
+from ._hash import Hasher
 from .context import Context
 from .context.data_handler import Dataset
 from .pipeline import Pipeline
 
-__all__ = ["Context", "Dataset", "Embedder", "Pipeline"]
+__all__ = ["Context", "Dataset", "Embedder", "Hasher", "Pipeline"]
```

autointent/_embedder.py

Lines changed: 53 additions & 2 deletions
```diff
@@ -12,8 +12,27 @@
 
 import numpy as np
 import numpy.typing as npt
+from appdirs import user_cache_dir
 from sentence_transformers import SentenceTransformer
 
+from ._hash import Hasher
+
+
+def get_embeddings_path(filename: str) -> Path:
+    """
+    Get the path to the embeddings file.
+
+    This function constructs the full path to an embeddings file stored
+    in a dedicated directory under the user's cache directory. The embeddings
+    file is named based on the provided filename, with the `.npy` extension
+    added.
+
+    :param filename: The name of the embeddings file (without extension).
+
+    :return: The full path to the embeddings file.
+    """
+    return Path(user_cache_dir("autointent")) / "embeddings" / f"{filename}.npy"
+
 
 class EmbedderDumpMetadata(TypedDict):
     """Metadata for saving and loading an Embedder instance."""
@@ -41,6 +60,7 @@ def __init__(
         device: str = "cpu",
         batch_size: int = 32,
         max_length: int | None = None,
+        use_cache: bool = False,
     ) -> None:
         """
         Initialize the Embedder.
@@ -49,11 +69,13 @@
         :param device: Device to run the model on (e.g., "cpu", "cuda").
         :param batch_size: Batch size for embedding calculations.
         :param max_length: Maximum sequence length for the embedding model.
+        :param use_cache: Flag indicating whether to cache intermediate embeddings.
         """
         self.model_name = model_name
         self.device = device
         self.batch_size = batch_size
         self.max_length = max_length
+        self.use_cache = use_cache
 
         if Path(model_name).exists():
             self.load(model_name)
@@ -62,6 +84,18 @@
 
         self.logger = logging.getLogger(__name__)
 
+    def __hash__(self) -> int:
+        """
+        Compute a hash value for the Embedder.
+
+        :returns: The hash value of the Embedder.
+        """
+        hasher = Hasher()
+        for parameter in self.embedding_model.parameters():
+            hasher.update(parameter.detach().cpu().numpy())
+        hasher.update(self.max_length)
+        return hasher.intdigest()
+
     def clear_ram(self) -> None:
         """Move the embedding model to CPU and delete it from memory."""
         self.logger.debug("Clearing embedder %s from memory", self.model_name)
@@ -114,18 +148,35 @@ def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]:
         :param utterances: List of input texts to calculate embeddings for.
         :return: A numpy array of embeddings.
         """
+        if self.use_cache:
+            hasher = Hasher()
+            hasher.update(self)
+            hasher.update(utterances)
+
+            embeddings_path = get_embeddings_path(hasher.hexdigest())
+            if embeddings_path.exists():
+                return np.load(embeddings_path)  # type: ignore[no-any-return]
+
         self.logger.debug(
             "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, device=%s",
             self.model_name,
             self.batch_size,
             str(self.max_length),
             self.device,
         )
+
         if self.max_length is not None:
             self.embedding_model.max_seq_length = self.max_length
-        return self.embedding_model.encode(
+
+        embeddings = self.embedding_model.encode(
             utterances,
             convert_to_numpy=True,
             batch_size=self.batch_size,
             normalize_embeddings=True,
-        )  # type: ignore[return-value]
+        )
+
+        if self.use_cache:
+            embeddings_path.parent.mkdir(parents=True, exist_ok=True)
+            np.save(embeddings_path, embeddings)
+
+        return embeddings  # type: ignore[return-value]
```
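Taken together, the new `use_cache` path keys the on-disk cache by a hash of the model weights, `max_length`, and the input utterances, so a repeated `embed` call with identical inputs is served from a `.npy` file instead of being recomputed. A minimal sketch of that behavior, using an arbitrary sentence-transformers checkpoint for illustration:

```python
from autointent import Embedder

# Any sentence-transformers checkpoint works here; this one is just an example.
embedder = Embedder(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    device="cpu",
    use_cache=True,
)

utterances = ["turn on the lights", "what's the weather like"]

# First call: embeddings are computed and written under
# user_cache_dir("autointent")/embeddings/<hexdigest>.npy.
first = embedder.embed(utterances)

# Second call with identical inputs: the hash matches, the file exists,
# and the embeddings are loaded from disk instead of recomputed.
second = embedder.embed(utterances)
```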

autointent/_hash.py

Lines changed: 72 additions & 0 deletions
```diff
@@ -0,0 +1,72 @@
+"""This module provides functionality for hashing data using the xxhash algorithm."""
+
+import pickle
+from typing import Any
+
+import xxhash
+
+
+class Hasher:
+    """
+    A class that provides methods for hashing data using xxhash.
+
+    This class supports both a class-level method for generating hashes from
+    any given value, as well as an instance-level method for progressively
+    updating a hash state with new values.
+    """
+
+    def __init__(self) -> None:
+        """
+        Initialize the Hasher instance and set up the internal xxhash state.
+
+        This state will be used for progressively hashing values using the
+        `update` method and obtaining the final digest using `hexdigest`.
+        """
+        self._state = xxhash.xxh64()
+
+    @classmethod
+    def hash(cls, value: Any) -> int:  # noqa: ANN401
+        """
+        Generate a hash for the given value using xxhash.
+
+        :param value: The value to be hashed. This can be any Python object.
+
+        :return: The resulting hash digest as an integer.
+        """
+        if hasattr(value, "__hash__") and value.__hash__ not in {None, object.__hash__}:
+            return hash(value)
+        return xxhash.xxh64(pickle.dumps(value)).intdigest()
+
+    def update(self, value: Any) -> None:  # noqa: ANN401
+        """
+        Update the internal hash state with the provided value.
+
+        This method will first hash the type of the value, then hash the value
+        itself, and update the internal state accordingly.
+
+        :param value: The value to update the hash state with.
+        """
+        self._state.update(str(type(value)).encode())
+        self._state.update(str(self.hash(value)).encode())
+
+    def hexdigest(self) -> str:
+        """
+        Return the current hash digest as a hexadecimal string.
+
+        This method should be called after one or more `update` calls to get
+        the final hash result.
+
+        :return: The resulting hash digest as a hexadecimal string.
+        """
+        return self._state.hexdigest()
+
+    def intdigest(self) -> int:
+        """
+        Return the current hash digest as an integer.
+
+        This method should be called after one or more `update` calls to get
+        the final hash result.
+
+        :return: The resulting hash digest as an integer.
+        """
+        return self._state.intdigest()
```
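The class supports both usage modes described in its docstring: one-shot hashing via the classmethod and incremental hashing via `update`. A short sketch with arbitrarily chosen values:

```python
from autointent import Hasher

# One-shot hashing: objects with a real __hash__ use Python's built-in hash();
# everything else (lists, dicts, numpy arrays, ...) is pickled and fed to xxh64.
Hasher.hash("a string")         # delegates to the built-in hash()
Hasher.hash(["a", "list", 42])  # list has no __hash__, so pickled + xxh64

# Incremental hashing: each update() mixes in the value's type and its hash.
hasher = Hasher()
hasher.update("sentence-transformers/all-MiniLM-L6-v2")
hasher.update(["turn on the lights", "what's the weather like"])
print(hasher.hexdigest())  # hex string; Embedder uses this as the cache filename
print(hasher.intdigest())  # the same state digested as an integer
```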

autointent/configs/_optimization_cli.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -107,6 +107,8 @@ class EmbedderConfig:
     """Batch size for the embedder"""
     max_length: int | None = None
     """Max length for the embedder. If None, the max length will be taken from model config"""
+    use_cache: bool = False
+    """Flag indicating whether to cache embeddings for reuse, improving performance in repeated operations."""
 
 
 @dataclass
```
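With the new field in place, caching can be switched on at configuration time. A hypothetical sketch, assuming `EmbedderConfig` is importable from `autointent.configs` and that fields not shown in this diff have defaults:

```python
from autointent.configs import EmbedderConfig

# Only the fields visible in this diff are spelled out; any other
# EmbedderConfig fields are assumed to keep their default values.
embedder_config = EmbedderConfig(
    batch_size=32,
    max_length=None,
    use_cache=True,  # the new flag: persist embeddings to the on-disk cache
)
```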

autointent/context/_context.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -68,6 +68,7 @@ def configure_vector_index(self, config: VectorIndexConfig, embedder_config: Emb
             self.vector_index_config.db_dir,
             self.embedder_config.batch_size,
             self.embedder_config.max_length,
+            self.embedder_config.use_cache,
         )
 
     def configure_data(self, config: DataConfig) -> None:
@@ -190,6 +191,14 @@ def get_max_length(self) -> int | None:
         """
         return self.vector_index_client.embedder_max_length
 
+    def get_use_cache(self) -> bool:
+        """
+        Check if caching is enabled for the embedder.
+
+        :return: True if caching is enabled, False otherwise.
+        """
+        return self.vector_index_client.embedder_use_cache
+
     def get_dump_dir(self) -> Path | None:
         """
         Get the directory for saving dumped modules.
```

autointent/context/vector_index_client/_vector_index.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -31,6 +31,7 @@
         device: str,
         embedder_batch_size: int = 32,
         embedder_max_length: int | None = None,
+        embedder_use_cache: bool = False,
     ) -> None:
         """
         Initialize the vector index.
@@ -39,13 +40,15 @@
         :param device: Device for running the embedding model (e.g., "cpu", "cuda").
         :param embedder_batch_size: Batch size for the embedder.
         :param embedder_max_length: Maximum sequence length for the embedder.
+        :param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
         """
         self.model_name = model_name
         self.embedder = Embedder(
             model_name=model_name,
             batch_size=embedder_batch_size,
             device=device,
             max_length=embedder_max_length,
+            use_cache=embedder_use_cache,
         )
         self.device = device
 
```

autointent/context/vector_index_client/_vector_index_client.py

Lines changed: 17 additions & 2 deletions
```diff
@@ -32,6 +32,7 @@
         db_dir: str | Path | None,
         embedder_batch_size: int = 32,
         embedder_max_length: int | None = None,
+        embedder_use_cache: bool = False,
     ) -> None:
         """
         Initialize the VectorIndexClient.
@@ -40,12 +41,14 @@
         :param db_dir: Directory for storing vector indexes. Defaults to a cache directory.
         :param embedder_batch_size: Batch size for the embedding model.
         :param embedder_max_length: Maximum sequence length for the embedding model.
+        :param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
         """
         self._logger = logging.getLogger(__name__)
         self.device = device
         self.db_dir = get_db_dir(db_dir)
         self.embedder_batch_size = embedder_batch_size
         self.embedder_max_length = embedder_max_length
+        self.embedder_use_cache = embedder_use_cache
 
     def create_index(
         self,
@@ -64,7 +67,13 @@
         """
         self._logger.info("Creating index for model: %s", model_name)
 
-        index = VectorIndex(model_name, self.device, self.embedder_batch_size, self.embedder_max_length)
+        index = VectorIndex(
+            model_name,
+            self.device,
+            self.embedder_batch_size,
+            self.embedder_max_length,
+            self.embedder_use_cache,
+        )
         if utterances is not None and labels is not None:
             index.add(utterances, labels)
         self.dump(index)
@@ -165,7 +174,13 @@ def get_index(self, model_name: str) -> VectorIndex:
         """
         dirpath = self._get_index_dirpath(model_name)
         if dirpath is not None:
-            index = VectorIndex(model_name, self.device, self.embedder_batch_size, self.embedder_max_length)
+            index = VectorIndex(
+                model_name,
+                self.device,
+                self.embedder_batch_size,
+                self.embedder_max_length,
+                self.embedder_use_cache,
+            )
             index.load(dirpath)
             return index
 
```
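The client stores the flag once and now forwards it to every `VectorIndex` it constructs, in both `create_index` and `get_index`. A sketch of direct use, assuming `VectorIndexClient` is exported from the `vector_index_client` package and that integer labels are valid here:

```python
from autointent.context.vector_index_client import VectorIndexClient

client = VectorIndexClient(
    device="cpu",
    db_dir=None,  # None falls back to the default cache directory via get_db_dir
    embedder_use_cache=True,
)

# The flag travels with the client: this index's Embedder is constructed
# with use_cache=True and reuses on-disk embeddings when it can.
index = client.create_index(
    "sentence-transformers/all-MiniLM-L6-v2",
    ["turn on the lights", "what's the weather like"],
    [0, 1],
)
```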

autointent/modules/retrieval/_vectordb.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -70,6 +70,7 @@
         device: str = "cpu",
         batch_size: int = 32,
         max_length: int | None = None,
+        embedder_use_cache: bool = False,
     ) -> None:
         """
         Initialize the VectorDBModule.
@@ -80,12 +81,14 @@
         :param device: Device to run operations on, e.g., "cpu" or "cuda".
         :param batch_size: Batch size for embedding generation.
         :param max_length: Maximum sequence length for embeddings. None if not set.
+        :param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
         """
         self.embedder_name = embedder_name
         self.device = device
         self._db_dir = db_dir
         self.batch_size = batch_size
         self.max_length = max_length
+        self.embedder_use_cache = embedder_use_cache
 
         super().__init__(k=k)
 
@@ -111,6 +114,7 @@ def from_context(
             device=context.get_device(),
             batch_size=context.get_batch_size(),
             max_length=context.get_max_length(),
+            embedder_use_cache=context.get_use_cache(),
         )
 
     @property
@@ -136,6 +140,7 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
             self.db_dir,
             embedder_batch_size=self.batch_size,
             embedder_max_length=self.max_length,
+            embedder_use_cache=self.embedder_use_cache,
         )
         self.vector_index = vector_index_client.create_index(self.embedder_name, utterances, labels)
 
@@ -209,6 +214,7 @@ def load(self, path: str) -> None:
             db_dir=self.metadata["db_dir"],
             embedder_batch_size=self.metadata["batch_size"],
             embedder_max_length=self.metadata["max_length"],
+            embedder_use_cache=self.embedder_use_cache,
        )
         self.vector_index = vector_index_client.get_index(self.embedder_name)
 
```