Skip to content

Commit 6545b83

Browse files
committed
feat: add `similarity_fn_name` argument and `Embedder.similarity` method
1 parent 5c66698 commit 6545b83

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

autointent/_embedder.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class EmbedderDumpMetadata(TypedDict):
5050
"""Maximum sequence length for the embedding model."""
5151
use_cache: bool
5252
"""Whether to use embeddings caching."""
53+
similarity_fn_name: str | None
54+
"""Name of the similarity function to use."""
5355

5456

5557
class Embedder:
@@ -73,7 +75,10 @@ def __init__(self, embedder_config: EmbedderConfig) -> None:
7375
self.config = embedder_config
7476

7577
self.embedding_model = SentenceTransformer(
76-
self.config.model_name, device=self.config.device, prompts=embedder_config.get_prompt_config()
78+
self.config.model_name,
79+
device=self.config.device,
80+
prompts=embedder_config.get_prompt_config(),
81+
similarity_fn_name=self.config.similarity_fn_name,
7782
)
7883

7984
self._logger = logging.getLogger(__name__)
@@ -116,6 +121,7 @@ def dump(self, path: Path) -> None:
116121
batch_size=self.config.batch_size,
117122
max_length=self.config.tokenizer_config.max_length,
118123
use_cache=self.config.use_cache,
124+
similarity_fn_name=self.config.similarity_fn_name,
119125
)
120126
path.mkdir(parents=True, exist_ok=True)
121127
with (path / self._metadata_dict_name).open("w") as file:
@@ -186,3 +192,17 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
186192
np.save(embeddings_path, embeddings)
187193

188194
return embeddings
195+
196+
def similarity(
    self, embeddings1: npt.NDArray[np.float32], embeddings2: npt.NDArray[np.float32]
) -> npt.NDArray[np.float32]:
    """Calculate pairwise similarity between two sets of embeddings.

    Uses the similarity function configured on the underlying
    SentenceTransformer model via ``similarity_fn_name`` (cosine by default).

    Args:
        embeddings1: First set of embeddings, shape ``(n, dim)``.
        embeddings2: Second set of embeddings, shape ``(m, dim)``.

    Returns:
        A numpy array of similarities with shape ``(n, m)``.
    """
    # SentenceTransformer.similarity returns a torch.Tensor; convert it to
    # numpy so the result matches the annotated return type and the numpy
    # code in callers (e.g. DescriptionScorer.predict) keeps working.
    scores = self.embedding_model.similarity(embeddings1, embeddings2)
    return scores.cpu().numpy()

autointent/configs/_transformers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ class EmbedderConfig(HFModelConfig):
6060
sts_prompt: str | None = Field(None, description="Prompt for finding most similar sentences.")
6161
query_prompt: str | None = Field(None, description="Prompt for query.")
6262
passage_prompt: str | None = Field(None, description="Prompt for passage.")
63+
similarity_fn_name: str | None = Field(
64+
"cosine", description="Name of the similarity function to use (cosine, dot, euclidean, manhattan)."
65+
)
6366

6467
def get_prompt_config(self) -> dict[str, str] | None:
6568
"""Get the prompt config for the given prompt type.

autointent/modules/scoring/_description/description.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import scipy
77
from numpy.typing import NDArray
88
from pydantic import PositiveFloat
9-
from sklearn.metrics.pairwise import cosine_similarity
109

1110
from autointent import Context, Embedder, Ranker
1211
from autointent.configs import CrossEncoderConfig, EmbedderConfig, TaskTypeEnum
@@ -159,7 +158,7 @@ def predict(self, utterances: list[str]) -> NDArray[np.float64]:
159158
"""
160159
if self._encoder_type == "bi":
161160
utterance_vectors = self._embedder.embed(utterances, TaskTypeEnum.sts)
162-
similarities: NDArray[np.float64] = cosine_similarity(utterance_vectors, self._description_vectors)
161+
similarities: NDArray[np.float64] = self._embedder.similarity(utterance_vectors, self._description_vectors)
163162
else:
164163
pairs = [(utterance, description) for utterance in utterances for description in self._description_texts]
165164

0 commit comments

Comments (0)