@@ -37,10 +37,16 @@ def get_embeddings_path(filename: str) -> Path:
 class EmbedderDumpMetadata(TypedDict):
     """Metadata for saving and loading an Embedder instance."""
 
+    model_name_or_path: str
+    """Name of the Hugging Face model or a local path to a SentenceTransformers dump."""
+    device: str
+    """Torch device notation, e.g. "cpu" or "cuda"."""
     batch_size: int
     """Batch size used for embedding calculations."""
     max_length: int | None
     """Maximum sequence length for the embedding model."""
+    use_cache: bool
+    """Whether to cache computed embeddings."""
 
 
 class Embedder:
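For reference, this is the full record that `dump` serializes under the new schema. A minimal sketch, assuming `EmbedderDumpMetadata` from above is in scope; the model name and dump directory are illustrative placeholders:

```python
import json
from pathlib import Path

# Illustrative values only; at runtime a TypedDict constructor returns a plain dict.
metadata = EmbedderDumpMetadata(
    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",  # placeholder model
    device="cpu",
    batch_size=32,
    max_length=None,
    use_cache=False,
)

with (Path("dump_dir") / "metadata.json").open("w") as file:  # placeholder directory
    json.dump(metadata, file, indent=4)
```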
@@ -51,12 +57,11 @@ class Embedder:
     embedding models, as well as calculating embeddings for input texts.
     """
 
-    embedder_subdir: str = "sentence_transformers"
     metadata_dict_name: str = "metadata.json"
 
     def __init__(
         self,
-        model_name: str | Path,
+        model_name_or_path: str | Path,
         device: str = "cpu",
         batch_size: int = 32,
         max_length: int | None = None,
@@ -71,16 +76,13 @@ def __init__(
         :param max_length: Maximum sequence length for the embedding model.
         :param use_cache: Flag indicating whether to cache intermediate embeddings.
         """
-        self.model_name = model_name
+        self.model_name = model_name_or_path
         self.device = device
         self.batch_size = batch_size
         self.max_length = max_length
         self.use_cache = use_cache
 
-        if Path(model_name).exists():
-            self.load(model_name)
-        else:
-            self.embedding_model = SentenceTransformer(str(model_name), device=device)
+        self.embedding_model = SentenceTransformer(str(model_name_or_path), device=device)
 
         self.logger = logging.getLogger(__name__)
 
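With the path-existence branch removed, a Hub model id and a local SentenceTransformers directory now take the same code path, since `SentenceTransformer(...)` resolves both forms itself. A minimal usage sketch; both arguments are placeholders:

```python
from pathlib import Path

# Hub model id and local dump directory are placeholders.
embedder_from_hub = Embedder("sentence-transformers/all-MiniLM-L6-v2", device="cpu")
embedder_from_disk = Embedder(Path("/models/my_st_dump"), device="cpu")
```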
@@ -105,10 +107,7 @@ def clear_ram(self) -> None:
     def delete(self) -> None:
         """Delete the embedding model and its associated directory."""
         self.clear_ram()
-        shutil.rmtree(
-            self.dump_dir,
-            ignore_errors=True,
-        )  # TODO: `ignore_errors=True` is workaround for PermissionError: [WinError 5] Access is denied
+        shutil.rmtree(self.dump_dir)
 
     def dump(self, path: Path) -> None:
         """
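Dropping `ignore_errors=True` means a failing removal surfaces again instead of passing silently. Should the Windows `PermissionError` from the old TODO resurface, one common pattern (an assumption, not part of this PR) is an `onerror` hook that clears the read-only bit and retries:

```python
import os
import shutil
import stat

def _force_remove(func, path, exc_info):
    # Clear the read-only attribute that typically causes WinError 5, then retry.
    os.chmod(path, stat.S_IWRITE)
    func(path)

shutil.rmtree("dump_dir", onerror=_force_remove)  # "dump_dir" is a placeholder
```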
@@ -118,28 +117,39 @@ def dump(self, path: Path) -> None:
         """
         self.dump_dir = path
         metadata = EmbedderDumpMetadata(
+            model_name_or_path=str(self.model_name),
+            device=self.device,
             batch_size=self.batch_size,
             max_length=self.max_length,
+            use_cache=self.use_cache,
         )
         path.mkdir(parents=True, exist_ok=True)
-        self.embedding_model.save(str(path / self.embedder_subdir))
         with (path / self.metadata_dict_name).open("w") as file:
             json.dump(metadata, file, indent=4)
 
-    def load(self, path: Path | str) -> None:
+    @classmethod
+    def load(
+        cls, path: Path | str, batch_size: int | None = None, use_cache: bool | None = None, device: str | None = None
+    ) -> "Embedder":
         """
         Load the embedding model and metadata from disk.
 
         :param path: Path to the directory where the model is stored.
+        :param batch_size: Optional override for the batch size stored in the metadata.
+        :param use_cache: Optional override for the caching flag stored in the metadata.
+        :param device: Optional override for the device stored in the metadata.
+        :return: A new Embedder instance built from the stored metadata.
         """
-        self.dump_dir = Path(path)
-        path = Path(path)
-        with (path / self.metadata_dict_name).open() as file:
+        with (Path(path) / cls.metadata_dict_name).open() as file:
             metadata: EmbedderDumpMetadata = json.load(file)
-        self.batch_size = metadata["batch_size"]
-        self.max_length = metadata["max_length"]
 
-        self.embedding_model = SentenceTransformer(str(path / self.embedder_subdir), device=self.device)
+        return cls(
+            model_name_or_path=metadata["model_name_or_path"],
+            device=device if device is not None else metadata["device"],
+            batch_size=batch_size if batch_size is not None else metadata["batch_size"],
+            max_length=metadata["max_length"],
+            use_cache=use_cache if use_cache is not None else metadata["use_cache"],
+        )
 
     def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]:
         """
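Taken together, `dump` now persists only `metadata.json` (the weights are re-resolved from `model_name_or_path` on load), and `load` becomes a factory classmethod whose keyword arguments override the stored values. A roundtrip sketch, with placeholder paths and model name:

```python
from pathlib import Path

embedder = Embedder("sentence-transformers/all-MiniLM-L6-v2", batch_size=32)
embedder.dump(Path("dumps/embedder"))  # writes only metadata.json

# Later: rebuild from metadata, overriding device and batch size.
restored = Embedder.load(Path("dumps/embedder"), device="cuda", batch_size=64)
```

Since the weights are no longer copied into the dump directory, restoring on another machine implies re-downloading the model unless `model_name_or_path` pointed at a local directory that also exists there.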
@@ -179,4 +189,4 @@ def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]:
         embeddings_path.parent.mkdir(parents=True, exist_ok=True)
         np.save(embeddings_path, embeddings)
 
-        return embeddings  # type: ignore[return-value]
+        return embeddings