 
 import numpy as np
 import numpy.typing as npt
+from appdirs import user_cache_dir
 from sentence_transformers import SentenceTransformer
 
+from ._hash import Hasher
+
+
+def get_embeddings_path(filename: str) -> Path:
+    """
+    Get the path to the embeddings file.
+
+    This function constructs the full path to an embeddings file stored
+    in a dedicated directory under the user's cache directory. The file
+    is named after the provided filename, with the `.npy` extension
+    appended.
+
+    :param filename: The name of the embeddings file (without extension).
+
+    :return: The full path to the embeddings file.
+    """
+    return Path(user_cache_dir("autointent")) / "embeddings" / f"{filename}.npy"
+
 
 class EmbedderDumpMetadata(TypedDict):
     """Metadata for saving and loading an Embedder instance."""
@@ -41,6 +60,7 @@ def __init__(
         device: str = "cpu",
         batch_size: int = 32,
         max_length: int | None = None,
+        use_cache: bool = False,
     ) -> None:
         """
         Initialize the Embedder.
@@ -49,11 +69,13 @@ def __init__(
         :param device: Device to run the model on (e.g., "cpu", "cuda").
         :param batch_size: Batch size for embedding calculations.
         :param max_length: Maximum sequence length for the embedding model.
+        :param use_cache: Whether to cache computed embeddings on disk.
         """
         self.model_name = model_name
         self.device = device
         self.batch_size = batch_size
         self.max_length = max_length
+        self.use_cache = use_cache
 
         if Path(model_name).exists():
             self.load(model_name)
@@ -62,6 +84,18 @@ def __init__(
 
         self.logger = logging.getLogger(__name__)
 
+    def __hash__(self) -> int:
+        """
+        Compute a hash value for the Embedder.
+
+        :returns: The hash value of the Embedder.
+        """
+        hasher = Hasher()
+        for parameter in self.embedding_model.parameters():
+            hasher.update(parameter.detach().cpu().numpy())
+        hasher.update(self.max_length)
+        return hasher.intdigest()
+
     def clear_ram(self) -> None:
         """Move the embedding model to CPU and delete it from memory."""
         self.logger.debug("Clearing embedder %s from memory", self.model_name)
@@ -114,18 +148,35 @@ def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]:
         :param utterances: List of input texts to calculate embeddings for.
         :return: A numpy array of embeddings.
         """
+        if self.use_cache:
+            hasher = Hasher()
+            hasher.update(self)
+            hasher.update(utterances)
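+            # The cache key covers the embedder state (see __hash__ above)
+            # and the exact utterance list, so a change to either yields a
+            # different cache file.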
+
+            embeddings_path = get_embeddings_path(hasher.hexdigest())
+            if embeddings_path.exists():
+                return np.load(embeddings_path)  # type: ignore[no-any-return]
+
         self.logger.debug(
             "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, device=%s",
             self.model_name,
             self.batch_size,
             str(self.max_length),
             self.device,
         )
+
         if self.max_length is not None:
             self.embedding_model.max_seq_length = self.max_length
-        return self.embedding_model.encode(
+
+        embeddings = self.embedding_model.encode(
             utterances,
             convert_to_numpy=True,
             batch_size=self.batch_size,
             normalize_embeddings=True,
-        )  # type: ignore[return-value]
+        )
+
+        if self.use_cache:
+            embeddings_path.parent.mkdir(parents=True, exist_ok=True)
+            np.save(embeddings_path, embeddings)
+
+        return embeddings  # type: ignore[return-value]
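
A minimal usage sketch of the new caching path (hedged: the model name and
import path below are illustrative, not taken from this diff):

    from autointent import Embedder  # import path assumed for illustration

    embedder = Embedder(
        "sentence-transformers/all-MiniLM-L6-v2",  # any SentenceTransformer model
        use_cache=True,
    )

    vectors = embedder.embed(["hello world"])        # computed, then saved as .npy
    vectors_again = embedder.embed(["hello world"])  # loaded from the on-disk cache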