1010from pathlib import Path
1111from typing import TypedDict
1212
13+ import huggingface_hub
1314import numpy as np
1415import numpy .typing as npt
1516import torch
2021from .configs import EmbedderConfig , TaskTypeEnum
2122
2223
23- def get_embeddings_path (filename : str ) -> Path :
24+ def _get_embeddings_path (filename : str ) -> Path :
2425 """Get the path to the embeddings file.
2526
2627 This function constructs the full path to an embeddings file stored
@@ -37,6 +38,18 @@ def get_embeddings_path(filename: str) -> Path:
3738 return Path (user_cache_dir ("autointent" )) / "embeddings" / f"{ filename } .npy"
3839
3940
41+ def _get_latest_commit_hash (model_name : str ) -> str :
42+ """Get the latest commit hash for a given Hugging Face model.
43+
44+ Args:
45+ model_name: The name of the model to get the latest commit hash for.
46+
47+ Returns:
48+ The latest commit hash for the given model name.
49+ """
50+ return huggingface_hub .model_info (model_name , revision = "main" ).sha
51+
52+
4053class EmbedderDumpMetadata (TypedDict ):
4154 """Metadata for saving and loading an Embedder instance."""
4255
@@ -63,7 +76,6 @@ class Embedder:
6376
6477 _metadata_dict_name : str = "metadata.json"
6578 _dump_dir : Path | None = None
66- config : EmbedderConfig
6779 embedding_model : SentenceTransformer
6880
6981 def __init__ (self , embedder_config : EmbedderConfig ) -> None :
@@ -74,14 +86,6 @@ def __init__(self, embedder_config: EmbedderConfig) -> None:
7486 """
7587 self .config = embedder_config
7688
77- self .embedding_model = SentenceTransformer (
78- self .config .model_name ,
79- device = self .config .device ,
80- prompts = embedder_config .get_prompt_config (),
81- similarity_fn_name = self .config .similarity_fn_name ,
82- trust_remote_code = self .config .trust_remote_code ,
83- )
84-
8589 self ._logger = logging .getLogger (__name__ )
8690
8791 def __hash__ (self ) -> int :
@@ -91,11 +95,27 @@ def __hash__(self) -> int:
9195 The hash value of the Embedder.
9296 """
9397 hasher = Hasher ()
94- for parameter in self .embedding_model .parameters ():
95- hasher .update (parameter .detach ().cpu ().numpy ())
98+ if self .config .freeze :
99+ commit_hash = _get_latest_commit_hash (self .config .model_name )
100+ hasher .update (commit_hash )
101+ else :
102+ self ._load_model ()
103+ for parameter in self .embedding_model .parameters ():
104+ hasher .update (parameter .detach ().cpu ().numpy ())
96105 hasher .update (self .config .tokenizer_config .max_length )
97106 return hasher .intdigest ()
98107
108+ def _load_model (self ) -> None :
109+ """Load sentence transformers model to device."""
110+ if not hasattr (self , "embedding_model" ):
111+ self .embedding_model = SentenceTransformer (
112+ self .config .model_name ,
113+ device = self .config .device ,
114+ prompts = self .config .get_prompt_config (),
115+ similarity_fn_name = self .config .similarity_fn_name ,
116+ trust_remote_code = self .config .trust_remote_code ,
117+ )
118+
99119 def clear_ram (self ) -> None :
100120 """Move the embedding model to CPU and delete it from memory."""
101121 self ._logger .debug ("Clearing embedder %s from memory" , self .config .model_name )
@@ -165,10 +185,12 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
165185 hasher .update (self )
166186 hasher .update (utterances )
167187
168- embeddings_path = get_embeddings_path (hasher .hexdigest ())
188+ embeddings_path = _get_embeddings_path (hasher .hexdigest ())
169189 if embeddings_path .exists ():
170190 return np .load (embeddings_path ) # type: ignore[no-any-return]
171191
192+ self ._load_model ()
193+
172194 self ._logger .debug (
173195 "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s" ,
174196 self .config .model_name ,
0 commit comments