 import shutil
 from functools import lru_cache
 from pathlib import Path
-from typing import TypedDict
 
 import huggingface_hub
 import numpy as np
@@ -59,23 +58,6 @@ def _get_latest_commit_hash(model_name: str) -> str:
     return commit_hash
 
 
-class EmbedderDumpMetadata(TypedDict):
-    """Metadata for saving and loading an Embedder instance."""
-
-    model_name: str
-    """Name of the hugging face model or a local path to sentence transformers dump."""
-    device: str | None
-    """Torch notation for CPU or CUDA."""
-    batch_size: int
-    """Batch size used for embedding calculations."""
-    max_length: int | None
-    """Maximum sequence length for the embedding model."""
-    use_cache: bool
-    """Whether to use embeddings caching."""
-    similarity_fn_name: str | None
-    """Name of the similarity function to use."""
-
-
 class Embedder:
     """A wrapper for managing embedding models using :py:class:`sentence_transformers.SentenceTransformer`.
 
@@ -85,7 +67,6 @@ class Embedder:
 
     _metadata_dict_name: str = "metadata.json"
     _dump_dir: Path | None = None
-    embedding_model: SentenceTransformer
 
     def __init__(self, embedder_config: EmbedderConfig) -> None:
         """Initialize the Embedder.
@@ -106,22 +87,25 @@ def _get_hash(self) -> int:
             commit_hash = _get_latest_commit_hash(self.config.model_name)
             hasher.update(commit_hash)
         else:
-            self._load_model()
+            self.embedding_model = self._load_model()
             for parameter in self.embedding_model.parameters():
                 hasher.update(parameter.detach().cpu().numpy())
         hasher.update(self.config.tokenizer_config.max_length)
         return hasher.intdigest()
 
-    def _load_model(self) -> None:
+    def _load_model(self) -> SentenceTransformer:
         """Load sentence transformers model to device."""
         if not hasattr(self, "embedding_model"):
-            self.embedding_model = SentenceTransformer(
+            res = SentenceTransformer(
                 self.config.model_name,
                 device=self.config.device,
                 prompts=self.config.get_prompt_config(),
                 similarity_fn_name=self.config.similarity_fn_name,
                 trust_remote_code=self.config.trust_remote_code,
             )
+        else:
+            res = self.embedding_model
+        return res
 
     def clear_ram(self) -> None:
         """Move the embedding model to CPU and delete it from memory."""
@@ -144,17 +128,9 @@ def dump(self, path: Path) -> None:
             path: Path to the directory where the model will be saved.
         """
         self._dump_dir = path
-        metadata = EmbedderDumpMetadata(
-            model_name=str(self.config.model_name),
-            device=self.config.device,
-            batch_size=self.config.batch_size,
-            max_length=self.config.tokenizer_config.max_length,
-            use_cache=self.config.use_cache,
-            similarity_fn_name=self.config.similarity_fn_name,
-        )
         path.mkdir(parents=True, exist_ok=True)
         with (path / self._metadata_dict_name).open("w") as file:
-            json.dump(metadata, file, indent=4)
+            json.dump(self.config.model_dump(mode="json"), file, indent=4)
 
     @classmethod
     def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -> "Embedder":
@@ -165,12 +141,12 @@ def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -
             override_config: one can override presaved settings
         """
         with (Path(path) / cls._metadata_dict_name).open(encoding="utf-8") as file:
-            metadata: EmbedderDumpMetadata = json.load(file)
+            config = EmbedderConfig.model_validate_json(file.read())
 
         if override_config is not None:
-            kwargs = {**metadata, **override_config.model_dump(exclude_unset=True)}
+            kwargs = {**config.model_dump(), **override_config.model_dump(exclude_unset=True)}
         else:
-            kwargs = metadata  # type: ignore[assignment]
+            kwargs = config.model_dump()
 
         max_length = kwargs.pop("max_length", None)
         if max_length is not None:
@@ -203,7 +179,7 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
             logger.debug("loading embeddings from %s", str(embeddings_path))
             return np.load(embeddings_path)  # type: ignore[no-any-return]
 
-        self._load_model()
+        self.embedding_model = self._load_model()
 
         logger.debug(
             "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s, prompt=%s",
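For context, a minimal sketch of the round-trip this change enables: dump() now serializes the whole EmbedderConfig via pydantic's model_dump(mode="json"), and load() restores it with EmbedderConfig.model_validate_json. The constructor keyword model_name and the example model id are assumptions for illustration (the exact EmbedderConfig signature is not shown in the diff), and the import path depends on the package layout:

# Hypothetical usage sketch; Embedder and EmbedderConfig come from the package
# under review (import path omitted).
from pathlib import Path

# model_name is assumed to be an accepted constructor kwarg; the field itself
# is visible in the removed EmbedderDumpMetadata definition.
config = EmbedderConfig(model_name="sentence-transformers/all-MiniLM-L6-v2")
embedder = Embedder(config)

dump_dir = Path("embedder_dump")
embedder.dump(dump_dir)    # writes metadata.json from config.model_dump(mode="json")
embedder.clear_ram()

restored = Embedder.load(dump_dir)  # reads it back via EmbedderConfig.model_validate_json

As in the updated load(), an override_config can still be passed; its model_dump(exclude_unset=True) takes precedence over the values stored in metadata.json.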