@@ -50,6 +50,8 @@ class EmbedderDumpMetadata(TypedDict):
5050 """Maximum sequence length for the embedding model."""
5151 use_cache : bool
5252 """Whether to use embeddings caching."""
53+ similarity_fn_name : str | None
54+ """Name of the similarity function to use."""
5355
5456
5557class Embedder :
@@ -73,7 +75,10 @@ def __init__(self, embedder_config: EmbedderConfig) -> None:
7375 self .config = embedder_config
7476
7577 self .embedding_model = SentenceTransformer (
76- self .config .model_name , device = self .config .device , prompts = embedder_config .get_prompt_config ()
78+ self .config .model_name ,
79+ device = self .config .device ,
80+ prompts = embedder_config .get_prompt_config (),
81+ similarity_fn_name = self .config .similarity_fn_name ,
7782 )
7883
7984 self ._logger = logging .getLogger (__name__ )
@@ -116,6 +121,7 @@ def dump(self, path: Path) -> None:
116121 batch_size = self .config .batch_size ,
117122 max_length = self .config .tokenizer_config .max_length ,
118123 use_cache = self .config .use_cache ,
124+ similarity_fn_name = self .config .similarity_fn_name ,
119125 )
120126 path .mkdir (parents = True , exist_ok = True )
121127 with (path / self ._metadata_dict_name ).open ("w" ) as file :
@@ -186,3 +192,17 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
186192 np .save (embeddings_path , embeddings )
187193
188194 return embeddings
195+
def similarity(
    self, embeddings1: npt.NDArray[np.float32], embeddings2: npt.NDArray[np.float32]
) -> npt.NDArray[np.float32]:
    """Calculate similarity between two sets of embeddings.

    Delegates to ``self.embedding_model.similarity``, i.e. the similarity
    function selected through ``similarity_fn_name`` when the underlying
    model was constructed.

    Args:
        embeddings1: First set of embeddings.
        embeddings2: Second set of embeddings.

    Returns:
        A float32 numpy array with the similarity scores computed by the
        underlying model for the two sets of embeddings.
    """
    scores = self.embedding_model.similarity(embeddings1, embeddings2)
    # SentenceTransformer.similarity returns a torch.Tensor (possibly living on
    # an accelerator device and attached to an autograd graph). Convert it so
    # the declared numpy return type actually holds for callers.
    if hasattr(scores, "detach"):
        scores = scores.detach().cpu().numpy()
    return np.asarray(scores, dtype=np.float32)
0 commit comments