Skip to content

Commit 389eb40

Browse files
Refactor/embedding caching (#195)
* implement new hashing strategy
* fix codestyle
* Update optimizer_config.schema.json
* minor bug fix
* fix typing error
* refactor similarity calculation
* Update optimizer_config.schema.json
* upd callback test
* solve 429 error

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 097f5ed commit 389eb40

File tree

5 files changed

+83
-44
lines changed

5 files changed

+83
-44
lines changed

autointent/_embedder.py

Lines changed: 55 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -7,20 +7,25 @@
77
import json
88
import logging
99
import shutil
10+
from functools import lru_cache
1011
from pathlib import Path
1112
from typing import TypedDict
1213

14+
import huggingface_hub
1315
import numpy as np
1416
import numpy.typing as npt
1517
import torch
1618
from appdirs import user_cache_dir
1719
from sentence_transformers import SentenceTransformer
20+
from sentence_transformers.similarity_functions import SimilarityFunction
1821

1922
from ._hash import Hasher
2023
from .configs import EmbedderConfig, TaskTypeEnum
2124

25+
logger = logging.getLogger(__name__)
2226

23-
def get_embeddings_path(filename: str) -> Path:
27+
28+
def _get_embeddings_path(filename: str) -> Path:
2429
"""Get the path to the embeddings file.
2530
2631
This function constructs the full path to an embeddings file stored
@@ -37,6 +42,23 @@ def get_embeddings_path(filename: str) -> Path:
3742
return Path(user_cache_dir("autointent")) / "embeddings" / f"{filename}.npy"
3843

3944

45+
@lru_cache(maxsize=128)
46+
def _get_latest_commit_hash(model_name: str) -> str:
47+
"""Get the latest commit hash for a given Hugging Face model.
48+
49+
Args:
50+
model_name: The name of the model to get the latest commit hash for.
51+
52+
Returns:
53+
The latest commit hash for the given model name or the model name if the commit hash is not found.
54+
"""
55+
commit_hash = huggingface_hub.model_info(model_name, revision="main").sha
56+
if commit_hash is None:
57+
logger.warning("No commit hash found for model %s", model_name)
58+
return model_name
59+
return commit_hash
60+
61+
4062
class EmbedderDumpMetadata(TypedDict):
4163
"""Metadata for saving and loading an Embedder instance."""
4264

@@ -63,7 +85,6 @@ class Embedder:
6385

6486
_metadata_dict_name: str = "metadata.json"
6587
_dump_dir: Path | None = None
66-
config: EmbedderConfig
6788
embedding_model: SentenceTransformer
6889

6990
def __init__(self, embedder_config: EmbedderConfig) -> None:
@@ -74,34 +95,41 @@ def __init__(self, embedder_config: EmbedderConfig) -> None:
7495
"""
7596
self.config = embedder_config
7697

77-
self.embedding_model = SentenceTransformer(
78-
self.config.model_name,
79-
device=self.config.device,
80-
prompts=embedder_config.get_prompt_config(),
81-
similarity_fn_name=self.config.similarity_fn_name,
82-
trust_remote_code=self.config.trust_remote_code,
83-
)
84-
85-
self._logger = logging.getLogger(__name__)
86-
8798
def __hash__(self) -> int:
8899
"""Compute a hash value for the Embedder.
89100
90101
Returns:
91102
The hash value of the Embedder.
92103
"""
93104
hasher = Hasher()
94-
for parameter in self.embedding_model.parameters():
95-
hasher.update(parameter.detach().cpu().numpy())
105+
if self.config.freeze:
106+
commit_hash = _get_latest_commit_hash(self.config.model_name)
107+
hasher.update(commit_hash)
108+
else:
109+
self._load_model()
110+
for parameter in self.embedding_model.parameters():
111+
hasher.update(parameter.detach().cpu().numpy())
96112
hasher.update(self.config.tokenizer_config.max_length)
97113
return hasher.intdigest()
98114

115+
def _load_model(self) -> None:
116+
"""Load sentence transformers model to device."""
117+
if not hasattr(self, "embedding_model"):
118+
self.embedding_model = SentenceTransformer(
119+
self.config.model_name,
120+
device=self.config.device,
121+
prompts=self.config.get_prompt_config(),
122+
similarity_fn_name=self.config.similarity_fn_name,
123+
trust_remote_code=self.config.trust_remote_code,
124+
)
125+
99126
def clear_ram(self) -> None:
100127
"""Move the embedding model to CPU and delete it from memory."""
101-
self._logger.debug("Clearing embedder %s from memory", self.config.model_name)
102-
self.embedding_model.cpu()
103-
del self.embedding_model
104-
torch.cuda.empty_cache()
128+
if hasattr(self, "embedding_model"):
129+
logger.debug("Clearing embedder %s from memory", self.config.model_name)
130+
self.embedding_model.cpu()
131+
del self.embedding_model
132+
torch.cuda.empty_cache()
105133

106134
def delete(self) -> None:
107135
"""Delete the embedding model and its associated directory."""
@@ -165,11 +193,13 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
165193
hasher.update(self)
166194
hasher.update(utterances)
167195

168-
embeddings_path = get_embeddings_path(hasher.hexdigest())
196+
embeddings_path = _get_embeddings_path(hasher.hexdigest())
169197
if embeddings_path.exists():
170198
return np.load(embeddings_path) # type: ignore[no-any-return]
171199

172-
self._logger.debug(
200+
self._load_model()
201+
202+
logger.debug(
173203
"Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s",
174204
self.config.model_name,
175205
self.config.batch_size,
@@ -200,11 +230,11 @@ def similarity(
200230
"""Calculate similarity between two sets of embeddings.
201231
202232
Args:
203-
embeddings1: First set of embeddings.
204-
embeddings2: Second set of embeddings.
233+
embeddings1: First set of embeddings (size n).
234+
embeddings2: Second set of embeddings (size m).
205235
206236
Returns:
207-
A numpy array of similarities.
237+
A numpy array of similarities (size n x m).
208238
"""
209-
result = self.embedding_model.similarity(embeddings1, embeddings2)
210-
return result.detach().cpu().numpy().astype(np.float32)
239+
similarity_fn = SimilarityFunction.to_similarity_fn(self.config.similarity_fn_name)
240+
return similarity_fn(embeddings1, embeddings2).detach().cpu().numpy().astype(np.float32)

autointent/configs/_transformers.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -61,9 +61,11 @@ class EmbedderConfig(HFModelConfig):
6161
sts_prompt: str | None = Field(None, description="Prompt for finding most similar sentences.")
6262
query_prompt: str | None = Field(None, description="Prompt for query.")
6363
passage_prompt: str | None = Field(None, description="Prompt for passage.")
64-
similarity_fn_name: str | None = Field(
65-
"cosine", description="Name of the similarity function to use (cosine, dot, euclidean, manhattan)."
64+
similarity_fn_name: Literal["cosine", "dot", "euclidean", "manhattan"] = Field(
65+
"cosine", description="Name of the similarity function to use."
6666
)
67+
use_cache: bool = Field(True, description="Whether to use embeddings caching.")
68+
freeze: bool = Field(True, description="Whether to freeze the model parameters.")
6769

6870
def get_prompt_config(self) -> dict[str, str] | None:
6971
"""Get the prompt config for the given prompt type.
@@ -111,8 +113,6 @@ def get_prompt_type(self, prompt_type: TaskTypeEnum | None) -> str | None: # no
111113
return self.default_prompt
112114
assert_never(prompt_type)
113115

114-
use_cache: bool = Field(False, description="Whether to use embeddings caching.")
115-
116116

117117
class CrossEncoderConfig(HFModelConfig):
118118
model_name: str = Field("cross-encoder/ms-marco-MiniLM-L6-v2", description="Name of the hugging face model.")

docs/optimizer_config.schema.json

Lines changed: 18 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -226,23 +226,28 @@
226226
"title": "Passage Prompt"
227227
},
228228
"similarity_fn_name": {
229-
"anyOf": [
230-
{
231-
"type": "string"
232-
},
233-
{
234-
"type": "null"
235-
}
236-
],
237229
"default": "cosine",
238-
"description": "Name of the similarity function to use (cosine, dot, euclidean, manhattan).",
239-
"title": "Similarity Fn Name"
230+
"description": "Name of the similarity function to use.",
231+
"enum": [
232+
"cosine",
233+
"dot",
234+
"euclidean",
235+
"manhattan"
236+
],
237+
"title": "Similarity Fn Name",
238+
"type": "string"
240239
},
241240
"use_cache": {
242-
"default": false,
241+
"default": true,
243242
"description": "Whether to use embeddings caching.",
244243
"title": "Use Cache",
245244
"type": "boolean"
245+
},
246+
"freeze": {
247+
"default": true,
248+
"description": "Whether to freeze the model parameters.",
249+
"title": "Freeze",
250+
"type": "boolean"
246251
}
247252
},
248253
"title": "EmbedderConfig",
@@ -418,7 +423,8 @@
418423
"query_prompt": null,
419424
"passage_prompt": null,
420425
"similarity_fn_name": "cosine",
421-
"use_cache": false
426+
"use_cache": true,
427+
"freeze": true
422428
}
423429
},
424430
"cross_encoder_config": {

tests/callback/test_callback.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -140,12 +140,13 @@ def test_pipeline_callbacks(dataset):
140140
"cluster_prompt": None,
141141
"default_prompt": None,
142142
"device": None,
143+
"freeze": True,
143144
"tokenizer_config": {"max_length": None, "truncation": True, "padding": True},
144145
"model_name": "sergeyzh/rubert-tiny-turbo",
145146
"passage_prompt": None,
146147
"query_prompt": None,
147148
"sts_prompt": None,
148-
"use_cache": False,
149+
"use_cache": True,
149150
"similarity_fn_name": "cosine",
150151
"trust_remote_code": False,
151152
},
@@ -176,12 +177,13 @@ def test_pipeline_callbacks(dataset):
176177
"cluster_prompt": None,
177178
"default_prompt": None,
178179
"device": None,
180+
"freeze": True,
179181
"tokenizer_config": {"max_length": None, "truncation": True, "padding": True},
180182
"model_name": "sergeyzh/rubert-tiny-turbo",
181183
"passage_prompt": None,
182184
"query_prompt": None,
183185
"sts_prompt": None,
184-
"use_cache": False,
186+
"use_cache": True,
185187
"similarity_fn_name": "cosine",
186188
"trust_remote_code": False,
187189
},
@@ -212,12 +214,13 @@ def test_pipeline_callbacks(dataset):
212214
"cluster_prompt": None,
213215
"default_prompt": None,
214216
"device": None,
217+
"freeze": True,
215218
"tokenizer_config": {"max_length": None, "truncation": True, "padding": True},
216219
"model_name": "sergeyzh/rubert-tiny-turbo",
217220
"passage_prompt": None,
218221
"query_prompt": None,
219222
"sts_prompt": None,
220-
"use_cache": False,
223+
"use_cache": True,
221224
"similarity_fn_name": "cosine",
222225
"trust_remote_code": False,
223226
},

user_guides/advanced/02_automl.py

100755 → 100644
File mode changed.

0 commit comments

Comments
 (0)