77import json
88import logging
99import shutil
10+ from functools import lru_cache
1011from pathlib import Path
1112from typing import TypedDict
1213
14+ import huggingface_hub
1315import numpy as np
1416import numpy .typing as npt
1517import torch
1618from appdirs import user_cache_dir
1719from sentence_transformers import SentenceTransformer
20+ from sentence_transformers .similarity_functions import SimilarityFunction
1821
1922from ._hash import Hasher
2023from .configs import EmbedderConfig , TaskTypeEnum
2124
25+ logger = logging .getLogger (__name__ )
2226
23- def get_embeddings_path (filename : str ) -> Path :
27+
28+ def _get_embeddings_path (filename : str ) -> Path :
2429 """Get the path to the embeddings file.
2530
2631 This function constructs the full path to an embeddings file stored
@@ -37,6 +42,23 @@ def get_embeddings_path(filename: str) -> Path:
3742 return Path (user_cache_dir ("autointent" )) / "embeddings" / f"{ filename } .npy"
3843
3944
45+ @lru_cache (maxsize = 128 )
46+ def _get_latest_commit_hash (model_name : str ) -> str :
47+ """Get the latest commit hash for a given Hugging Face model.
48+
49+ Args:
50+ model_name: The name of the model to get the latest commit hash for.
51+
52+ Returns:
53+ The latest commit hash for the given model name or the model name if the commit hash is not found.
54+ """
55+ commit_hash = huggingface_hub .model_info (model_name , revision = "main" ).sha
56+ if commit_hash is None :
57+ logger .warning ("No commit hash found for model %s" , model_name )
58+ return model_name
59+ return commit_hash
60+
61+
4062class EmbedderDumpMetadata (TypedDict ):
4163 """Metadata for saving and loading an Embedder instance."""
4264
@@ -63,7 +85,6 @@ class Embedder:
6385
6486 _metadata_dict_name : str = "metadata.json"
6587 _dump_dir : Path | None = None
66- config : EmbedderConfig
6788 embedding_model : SentenceTransformer
6889
6990 def __init__ (self , embedder_config : EmbedderConfig ) -> None :
@@ -74,34 +95,41 @@ def __init__(self, embedder_config: EmbedderConfig) -> None:
7495 """
7596 self .config = embedder_config
7697
77- self .embedding_model = SentenceTransformer (
78- self .config .model_name ,
79- device = self .config .device ,
80- prompts = embedder_config .get_prompt_config (),
81- similarity_fn_name = self .config .similarity_fn_name ,
82- trust_remote_code = self .config .trust_remote_code ,
83- )
84-
85- self ._logger = logging .getLogger (__name__ )
86-
8798 def __hash__ (self ) -> int :
8899 """Compute a hash value for the Embedder.
89100
90101 Returns:
91102 The hash value of the Embedder.
92103 """
93104 hasher = Hasher ()
94- for parameter in self .embedding_model .parameters ():
95- hasher .update (parameter .detach ().cpu ().numpy ())
105+ if self .config .freeze :
106+ commit_hash = _get_latest_commit_hash (self .config .model_name )
107+ hasher .update (commit_hash )
108+ else :
109+ self ._load_model ()
110+ for parameter in self .embedding_model .parameters ():
111+ hasher .update (parameter .detach ().cpu ().numpy ())
96112 hasher .update (self .config .tokenizer_config .max_length )
97113 return hasher .intdigest ()
98114
115+ def _load_model (self ) -> None :
116+ """Load sentence transformers model to device."""
117+ if not hasattr (self , "embedding_model" ):
118+ self .embedding_model = SentenceTransformer (
119+ self .config .model_name ,
120+ device = self .config .device ,
121+ prompts = self .config .get_prompt_config (),
122+ similarity_fn_name = self .config .similarity_fn_name ,
123+ trust_remote_code = self .config .trust_remote_code ,
124+ )
125+
99126 def clear_ram (self ) -> None :
100127 """Move the embedding model to CPU and delete it from memory."""
101- self ._logger .debug ("Clearing embedder %s from memory" , self .config .model_name )
102- self .embedding_model .cpu ()
103- del self .embedding_model
104- torch .cuda .empty_cache ()
128+ if hasattr (self , "embedding_model" ):
129+ logger .debug ("Clearing embedder %s from memory" , self .config .model_name )
130+ self .embedding_model .cpu ()
131+ del self .embedding_model
132+ torch .cuda .empty_cache ()
105133
106134 def delete (self ) -> None :
107135 """Delete the embedding model and its associated directory."""
@@ -165,11 +193,13 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
165193 hasher .update (self )
166194 hasher .update (utterances )
167195
168- embeddings_path = get_embeddings_path (hasher .hexdigest ())
196+ embeddings_path = _get_embeddings_path (hasher .hexdigest ())
169197 if embeddings_path .exists ():
170198 return np .load (embeddings_path ) # type: ignore[no-any-return]
171199
172- self ._logger .debug (
200+ self ._load_model ()
201+
202+ logger .debug (
173203 "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s" ,
174204 self .config .model_name ,
175205 self .config .batch_size ,
@@ -200,11 +230,11 @@ def similarity(
200230 """Calculate similarity between two sets of embeddings.
201231
202232 Args:
203- embeddings1: First set of embeddings.
204- embeddings2: Second set of embeddings.
233+ embeddings1: First set of embeddings (size n) .
234+ embeddings2: Second set of embeddings (size m) .
205235
206236 Returns:
207- A numpy array of similarities.
237+ A numpy array of similarities (size n x m) .
208238 """
209- result = self .embedding_model . similarity ( embeddings1 , embeddings2 )
210- return result .detach ().cpu ().numpy ().astype (np .float32 )
239+ similarity_fn = SimilarityFunction . to_similarity_fn ( self .config . similarity_fn_name )
240+ return similarity_fn ( embeddings1 , embeddings2 ) .detach ().cpu ().numpy ().astype (np .float32 )
0 commit comments