Skip to content

Commit 8a875b0

Browse files
committed
implement new hashing strategy
1 parent 097f5ed commit 8a875b0

File tree

2 files changed

+37
-14
lines changed

2 files changed

+37
-14
lines changed

autointent/_embedder.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pathlib import Path
1111
from typing import TypedDict
1212

13+
import huggingface_hub
1314
import numpy as np
1415
import numpy.typing as npt
1516
import torch
@@ -20,7 +21,7 @@
2021
from .configs import EmbedderConfig, TaskTypeEnum
2122

2223

23-
def get_embeddings_path(filename: str) -> Path:
24+
def _get_embeddings_path(filename: str) -> Path:
2425
"""Get the path to the embeddings file.
2526
2627
This function constructs the full path to an embeddings file stored
@@ -37,6 +38,18 @@ def get_embeddings_path(filename: str) -> Path:
3738
return Path(user_cache_dir("autointent")) / "embeddings" / f"{filename}.npy"
3839

3940

41+
def _get_latest_commit_hash(model_name: str) -> str:
42+
"""Get the latest commit hash for a given Hugging Face model.
43+
44+
Args:
45+
model_name: The name of the model to get the latest commit hash for.
46+
47+
Returns:
48+
The latest commit hash for the given model name.
49+
"""
50+
return huggingface_hub.model_info(model_name, revision="main").sha
51+
52+
4053
class EmbedderDumpMetadata(TypedDict):
4154
"""Metadata for saving and loading an Embedder instance."""
4255

@@ -63,7 +76,6 @@ class Embedder:
6376

6477
_metadata_dict_name: str = "metadata.json"
6578
_dump_dir: Path | None = None
66-
config: EmbedderConfig
6779
embedding_model: SentenceTransformer
6880

6981
def __init__(self, embedder_config: EmbedderConfig) -> None:
@@ -74,14 +86,6 @@ def __init__(self, embedder_config: EmbedderConfig) -> None:
7486
"""
7587
self.config = embedder_config
7688

77-
self.embedding_model = SentenceTransformer(
78-
self.config.model_name,
79-
device=self.config.device,
80-
prompts=embedder_config.get_prompt_config(),
81-
similarity_fn_name=self.config.similarity_fn_name,
82-
trust_remote_code=self.config.trust_remote_code,
83-
)
84-
8589
self._logger = logging.getLogger(__name__)
8690

8791
def __hash__(self) -> int:
@@ -91,11 +95,27 @@ def __hash__(self) -> int:
9195
The hash value of the Embedder.
9296
"""
9397
hasher = Hasher()
94-
for parameter in self.embedding_model.parameters():
95-
hasher.update(parameter.detach().cpu().numpy())
98+
if self.config.freeze:
99+
commit_hash = _get_latest_commit_hash(self.config.model_name)
100+
hasher.update(commit_hash)
101+
else:
102+
self._load_model()
103+
for parameter in self.embedding_model.parameters():
104+
hasher.update(parameter.detach().cpu().numpy())
96105
hasher.update(self.config.tokenizer_config.max_length)
97106
return hasher.intdigest()
98107

108+
def _load_model(self) -> None:
109+
"""Load sentence transformers model to device."""
110+
if not hasattr(self, "embedding_model"):
111+
self.embedding_model = SentenceTransformer(
112+
self.config.model_name,
113+
device=self.config.device,
114+
prompts=self.config.get_prompt_config(),
115+
similarity_fn_name=self.config.similarity_fn_name,
116+
trust_remote_code=self.config.trust_remote_code,
117+
)
118+
99119
def clear_ram(self) -> None:
100120
"""Move the embedding model to CPU and delete it from memory."""
101121
self._logger.debug("Clearing embedder %s from memory", self.config.model_name)
@@ -165,10 +185,12 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
165185
hasher.update(self)
166186
hasher.update(utterances)
167187

168-
embeddings_path = get_embeddings_path(hasher.hexdigest())
188+
embeddings_path = _get_embeddings_path(hasher.hexdigest())
169189
if embeddings_path.exists():
170190
return np.load(embeddings_path) # type: ignore[no-any-return]
171191

192+
self._load_model()
193+
172194
self._logger.debug(
173195
"Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s",
174196
self.config.model_name,

autointent/configs/_transformers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ class EmbedderConfig(HFModelConfig):
6464
similarity_fn_name: str | None = Field(
6565
"cosine", description="Name of the similarity function to use (cosine, dot, euclidean, manhattan)."
6666
)
67+
use_cache: bool = Field(True, description="Whether to use embeddings caching.")
68+
freeze: bool = Field(True, description="Whether to freeze the model parameters.")
6769

6870
def get_prompt_config(self) -> dict[str, str] | None:
6971
"""Get the prompt config for the given prompt type.
@@ -111,7 +113,6 @@ def get_prompt_type(self, prompt_type: TaskTypeEnum | None) -> str | None: # no
111113
return self.default_prompt
112114
assert_never(prompt_type)
113115

114-
use_cache: bool = Field(False, description="Whether to use embeddings caching.")
115116

116117

117118
class CrossEncoderConfig(HFModelConfig):

0 commit comments

Comments
 (0)