 import tempfile
 from functools import lru_cache
 from pathlib import Path
-from typing import TypedDict
 
 import huggingface_hub
 import numpy as np
@@ -65,23 +64,6 @@ def _get_latest_commit_hash(model_name: str) -> str:
     return commit_hash
 
 
-class EmbedderDumpMetadata(TypedDict):
-    """Metadata for saving and loading an Embedder instance."""
-
-    model_name: str
-    """Name of the hugging face model or a local path to sentence transformers dump."""
-    device: str | None
-    """Torch notation for CPU or CUDA."""
-    batch_size: int
-    """Batch size used for embedding calculations."""
-    max_length: int | None
-    """Maximum sequence length for the embedding model."""
-    use_cache: bool
-    """Whether to use embeddings caching."""
-    similarity_fn_name: str | None
-    """Name of the similarity function to use."""
-
-
 class Embedder:
     """A wrapper for managing embedding models using :py:class:`sentence_transformers.SentenceTransformer`.
 
@@ -91,7 +73,6 @@ class Embedder:
 
     _metadata_dict_name: str = "metadata.json"
     _dump_dir: Path | None = None
-    embedding_model: SentenceTransformer
 
     def __init__(self, embedder_config: EmbedderConfig) -> None:
         """Initialize the Embedder.
@@ -112,22 +93,25 @@ def _get_hash(self) -> int:
             commit_hash = _get_latest_commit_hash(self.config.model_name)
             hasher.update(commit_hash)
         else:
-            self._load_model()
+            self.embedding_model = self._load_model()
             for parameter in self.embedding_model.parameters():
                 hasher.update(parameter.detach().cpu().numpy())
         hasher.update(self.config.tokenizer_config.max_length)
         return hasher.intdigest()
 
-    def _load_model(self) -> None:
+    def _load_model(self) -> SentenceTransformer:
         """Load sentence transformers model to device."""
         if not hasattr(self, "embedding_model"):
-            self.embedding_model = SentenceTransformer(
+            res = SentenceTransformer(
                 self.config.model_name,
                 device=self.config.device,
                 prompts=self.config.get_prompt_config(),
                 similarity_fn_name=self.config.similarity_fn_name,
                 trust_remote_code=self.config.trust_remote_code,
             )
+        else:
+            res = self.embedding_model
+        return res
 
     def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTuningConfig) -> None:
         """Train the embedding model."""
@@ -199,17 +183,9 @@ def dump(self, path: Path) -> None:
             path: Path to the directory where the model will be saved.
         """
         self._dump_dir = path
-        metadata = EmbedderDumpMetadata(
-            model_name=str(self.config.model_name),
-            device=self.config.device,
-            batch_size=self.config.batch_size,
-            max_length=self.config.tokenizer_config.max_length,
-            use_cache=self.config.use_cache,
-            similarity_fn_name=self.config.similarity_fn_name,
-        )
         path.mkdir(parents=True, exist_ok=True)
         with (path / self._metadata_dict_name).open("w") as file:
-            json.dump(metadata, file, indent=4)
+            json.dump(self.config.model_dump(mode="json"), file, indent=4)
 
     @classmethod
     def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -> "Embedder":
@@ -220,12 +196,12 @@ def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -
             override_config: one can override presaved settings
         """
         with (Path(path) / cls._metadata_dict_name).open(encoding="utf-8") as file:
-            metadata: EmbedderDumpMetadata = json.load(file)
+            config = EmbedderConfig.model_validate_json(file.read())
 
         if override_config is not None:
-            kwargs = {**metadata, **override_config.model_dump(exclude_unset=True)}
+            kwargs = {**config.model_dump(), **override_config.model_dump(exclude_unset=True)}
         else:
-            kwargs = metadata  # type: ignore[assignment]
+            kwargs = config.model_dump()
 
         max_length = kwargs.pop("max_length", None)
         if max_length is not None:
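
The merge in `load` leans on pydantic v2 semantics: `override_config.model_dump(exclude_unset=True)` emits only the fields the caller explicitly set, so defaults on the override cannot clobber saved values. A minimal sketch with a hypothetical two-field model (not the real `EmbedderConfig`):

from pydantic import BaseModel

class DemoConfig(BaseModel):
    model_name: str = "base"
    batch_size: int = 32

saved = DemoConfig(model_name="saved-model", batch_size=8)
override = DemoConfig(batch_size=64)  # only batch_size explicitly set

# exclude_unset=True drops model_name from the override dump,
# so the saved value survives the merge.
merged = {**saved.model_dump(), **override.model_dump(exclude_unset=True)}
assert merged == {"model_name": "saved-model", "batch_size": 64}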
@@ -258,7 +234,7 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
             logger.debug("loading embeddings from %s", str(embeddings_path))
             return np.load(embeddings_path)  # type: ignore[no-any-return]
 
-        self._load_model()
+        self.embedding_model = self._load_model()
 
         logger.debug(
             "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s, prompt=%s",