11import json
22import logging
3+ import shutil
34from pathlib import Path
45from typing import TypedDict
56
@@ -37,12 +38,19 @@ def __init__(
3738
3839 self .logger = logging .getLogger (__name__ )
3940
40- def delete (self ) -> None :
def clear_ram(self) -> None:
    """Release the in-memory embedding model.

    Calls ``.cpu()`` on the attached model and then removes the
    ``embedding_model`` attribute from this instance. If no model is
    attached (it was never set or has already been cleared) the call is a
    no-op, so it is safe to invoke repeatedly, e.g. ``clear_ram()``
    followed by ``delete()``.
    """
    if not hasattr(self, "embedding_model"):
        # Nothing to release; avoid AttributeError on a repeated call.
        return
    self.logger.debug("deleting embedder %s", self.model_name)
    self.embedding_model.cpu()
    del self.embedding_model
4445
def delete(self) -> None:
    """Release the model from memory and remove its dump directory.

    First calls ``clear_ram()``, then recursively removes the directory
    stored in ``self.dump_dir``. That attribute is only assigned by
    ``dump()`` and ``load()``; when neither has been called there is
    nothing on disk to remove, so the removal step is skipped instead of
    raising ``AttributeError``.
    """
    self.clear_ram()
    dump_dir = getattr(self, "dump_dir", None)
    if dump_dir is None:
        return
    # TODO: `ignore_errors=True` is a workaround for
    # "PermissionError: [WinError 5] Access is denied" seen on Windows;
    # it also means a failed removal goes unreported.
    shutil.rmtree(dump_dir, ignore_errors=True)
51+
4552 def dump (self , path : Path ) -> None :
53+ self .dump_dir = path
4654 metadata = EmbedderDumpMetadata (
4755 batch_size = self .batch_size ,
4856 max_length = self .max_length ,
@@ -53,6 +61,7 @@ def dump(self, path: Path) -> None:
5361 json .dump (metadata , file , indent = 4 )
5462
5563 def load (self , path : Path | str ) -> None :
64+ self .dump_dir = Path (path )
5665 path = Path (path )
5766 with (path / self .metadata_dict_name ).open () as file :
5867 metadata : EmbedderDumpMetadata = json .load (file )
@@ -71,4 +80,9 @@ def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]:
7180 )
7281 if self .max_length is not None :
7382 self .embedding_model .max_seq_length = self .max_length
74- return self .embedding_model .encode (utterances , convert_to_numpy = True , batch_size = self .batch_size ) # type: ignore[return-value]
83+ return self .embedding_model .encode (
84+ utterances ,
85+ convert_to_numpy = True ,
86+ batch_size = self .batch_size ,
87+ normalize_embeddings = True ,
88+ ) # type: ignore[return-value]
0 commit comments