Skip to content

Commit 520b7f8

Browse files
Authored by Darinochka, with co-authors github-actions[bot], Samoed, and voorhs
feat: added crossencoder (#181)
* feat: added crossencoder * refactor * feat: added arg similarity * Update optimizer_config.schema.json * feat: added tests * feat: added errors * fix: scoring test * fix: description vectors error * fix: description vectors error * fix: lint * fix: test * add node validators (#177) * add node validators * add comments * Update optimizer_config.schema.json * rename bert model * lint * fixes * fix test --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: voorhs <[email protected]> * fix: unit tests * feat: added test for description * feat: delete encoder_type from the class args * feat: update assets * feat: update assets * fix: fixed test * Update optimizer_config.schema.json --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Roman Solomatin <[email protected]> Co-authored-by: voorhs <[email protected]>
1 parent 57b0b6a commit 520b7f8

File tree

8 files changed

+190
-22
lines changed

8 files changed

+190
-22
lines changed

autointent/_embedder.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class EmbedderDumpMetadata(TypedDict):
5050
"""Maximum sequence length for the embedding model."""
5151
use_cache: bool
5252
"""Whether to use embeddings caching."""
53+
similarity_fn_name: str | None
54+
"""Name of the similarity function to use."""
5355

5456

5557
class Embedder:
@@ -76,6 +78,7 @@ def __init__(self, embedder_config: EmbedderConfig) -> None:
7678
self.config.model_name,
7779
device=self.config.device,
7880
prompts=embedder_config.get_prompt_config(),
81+
similarity_fn_name=self.config.similarity_fn_name,
7982
trust_remote_code=self.config.trust_remote_code,
8083
)
8184

@@ -119,6 +122,7 @@ def dump(self, path: Path) -> None:
119122
batch_size=self.config.batch_size,
120123
max_length=self.config.tokenizer_config.max_length,
121124
use_cache=self.config.use_cache,
125+
similarity_fn_name=self.config.similarity_fn_name,
122126
)
123127
path.mkdir(parents=True, exist_ok=True)
124128
with (path / self._metadata_dict_name).open("w") as file:
@@ -189,3 +193,18 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
189193
np.save(embeddings_path, embeddings)
190194

191195
return embeddings
196+
197+
def similarity(
198+
self, embeddings1: npt.NDArray[np.float32], embeddings2: npt.NDArray[np.float32]
199+
) -> npt.NDArray[np.float32]:
200+
"""Calculate similarity between two sets of embeddings.
201+
202+
Args:
203+
embeddings1: First set of embeddings.
204+
embeddings2: Second set of embeddings.
205+
206+
Returns:
207+
A numpy array of similarities.
208+
"""
209+
result = self.embedding_model.similarity(embeddings1, embeddings2)
210+
return result.detach().cpu().numpy().astype(np.float32)

autointent/configs/_transformers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ class EmbedderConfig(HFModelConfig):
6161
sts_prompt: str | None = Field(None, description="Prompt for finding most similar sentences.")
6262
query_prompt: str | None = Field(None, description="Prompt for query.")
6363
passage_prompt: str | None = Field(None, description="Prompt for passage.")
64+
similarity_fn_name: str | None = Field(
65+
"cosine", description="Name of the similarity function to use (cosine, dot, euclidean, manhattan)."
66+
)
6467

6568
def get_prompt_config(self) -> dict[str, str] | None:
6669
"""Get the prompt config for the given prompt type.

autointent/modules/scoring/_description/description.py

Lines changed: 81 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
"""DescriptionScorer class for scoring utterances based on intent descriptions."""
22

3-
from typing import Any
3+
from typing import Any, Literal
44

55
import numpy as np
66
import scipy
77
from numpy.typing import NDArray
88
from pydantic import PositiveFloat
9-
from sklearn.metrics.pairwise import cosine_similarity
109

11-
from autointent import Context, Embedder
12-
from autointent.configs import EmbedderConfig, TaskTypeEnum
10+
from autointent import Context, Embedder, Ranker
11+
from autointent.configs import CrossEncoderConfig, EmbedderConfig, TaskTypeEnum
1312
from autointent.context.optimization_info import ScorerArtifact
1413
from autointent.custom_types import ListOfLabels
1514
from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
@@ -19,29 +18,38 @@
1918
class DescriptionScorer(BaseScorer):
2019
"""Scoring module that scores utterances based on similarity to intent descriptions.
2120
22-
DescriptionScorer embeds both the utterances and the intent descriptions, then computes
23-
a similarity score between the two, using either cosine similarity and softmax.
21+
DescriptionScorer can use either a bi-encoder or cross-encoder architecture:
22+
- Bi-encoder: Embeds both utterances and descriptions separately, then computes cosine similarity
23+
- Cross-encoder: Directly computes similarity between each utterance-description pair
2424
2525
Args:
26-
embedder_config: Config of the embedder model
26+
embedder_config: Config of the embedder model (for bi-encoder mode)
27+
cross_encoder_config: Config of the cross-encoder model (for cross-encoder mode)
28+
encoder_type: Type of encoder to use, either "bi" or "cross"
2729
temperature: Temperature parameter for scaling logits, defaults to 1.0
2830
"""
2931

30-
_embedder: Embedder
32+
_embedder: Embedder | None = None
33+
_cross_encoder: Ranker | None = None
3134
name = "description"
3235
_n_classes: int
3336
_multilabel: bool
34-
_description_vectors: NDArray[Any]
37+
_description_vectors: NDArray[Any] | None = None
38+
_description_texts: list[str] | None = None
3539
supports_multiclass = True
3640
supports_multilabel = True
3741

3842
def __init__(
3943
self,
4044
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
45+
cross_encoder_config: CrossEncoderConfig | str | dict[str, Any] | None = None,
46+
encoder_type: Literal["bi", "cross"] = "bi",
4147
temperature: PositiveFloat = 1.0,
4248
) -> None:
4349
self.temperature = temperature
4450
self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
51+
self.cross_encoder_config = CrossEncoderConfig.from_search_config(cross_encoder_config)
52+
self._encoder_type = encoder_type
4553

4654
if self.temperature < 0 or not isinstance(self.temperature, float | int):
4755
msg = "`temperature` argument of `DescriptionScorer` must be a positive float"
@@ -51,35 +59,51 @@ def __init__(
5159
def from_context(
5260
cls,
5361
context: Context,
54-
temperature: PositiveFloat,
62+
temperature: PositiveFloat = 1.0,
5563
embedder_config: EmbedderConfig | str | None = None,
64+
cross_encoder_config: CrossEncoderConfig | str | None = None,
65+
encoder_type: Literal["bi", "cross"] = "bi",
5666
) -> "DescriptionScorer":
5767
"""Create a DescriptionScorer instance using a Context object.
5868
5969
Args:
6070
context: Context containing configurations and utilities
6171
temperature: Temperature parameter for scaling logits
6272
embedder_config: Config of the embedder model. If None, the best embedder is used
73+
cross_encoder_config: Config of the cross-encoder model. If None, the default config is used
74+
encoder_type: Type of encoder to use, either "bi" or "cross"
6375
6476
Returns:
6577
Initialized DescriptionScorer instance
6678
"""
6779
if embedder_config is None:
6880
embedder_config = context.resolve_embedder()
81+
if cross_encoder_config is None:
82+
cross_encoder_config = context.resolve_ranker()
6983

7084
return cls(
7185
temperature=temperature,
7286
embedder_config=embedder_config,
87+
cross_encoder_config=cross_encoder_config,
88+
encoder_type=encoder_type,
7389
)
7490

7591
def get_embedder_config(self) -> dict[str, Any]:
76-
"""Get the name of the embedder.
92+
"""Get the configuration of the embedder.
7793
7894
Returns:
79-
Embedder name
95+
Embedder configuration
8096
"""
8197
return self.embedder_config.model_dump()
8298

99+
def get_cross_encoder_config(self) -> dict[str, Any]:
100+
"""Get the configuration of the cross-encoder.
101+
102+
Returns:
103+
Cross-encoder configuration
104+
"""
105+
return self.cross_encoder_config.model_dump()
106+
83107
def fit(
84108
self,
85109
utterances: list[str],
@@ -96,8 +120,10 @@ def fit(
96120
Raises:
97121
ValueError: If descriptions contain None values or embeddings mismatch utterances
98122
"""
99-
if hasattr(self, "_embedder"):
123+
if hasattr(self, "_embedder") and self._embedder is not None:
100124
self._embedder.clear_ram()
125+
if hasattr(self, "_cross_encoder") and self._cross_encoder is not None:
126+
self._cross_encoder.clear_ram()
101127

102128
self._validate_task(labels)
103129

@@ -108,10 +134,17 @@ def fit(
108134
)
109135
raise ValueError(error_text)
110136

111-
embedder = Embedder(self.embedder_config)
112-
113-
self._description_vectors = embedder.embed(descriptions, TaskTypeEnum.sts)
114-
self._embedder = embedder
137+
if self._encoder_type == "bi":
138+
embedder = Embedder(self.embedder_config)
139+
self._description_vectors = embedder.embed(descriptions, TaskTypeEnum.sts)
140+
self._embedder = embedder
141+
self._cross_encoder = None
142+
self._description_texts = None
143+
else:
144+
self._cross_encoder = Ranker(self.cross_encoder_config)
145+
self._description_texts = descriptions
146+
self._embedder = None
147+
self._description_vectors = None
115148

116149
def predict(self, utterances: list[str]) -> NDArray[np.float64]:
117150
"""Predict scores for utterances based on similarity to intent descriptions.
@@ -122,8 +155,32 @@ def predict(self, utterances: list[str]) -> NDArray[np.float64]:
122155
Returns:
123156
Array of probabilities for each utterance
124157
"""
125-
utterance_vectors = self._embedder.embed(utterances, TaskTypeEnum.sts)
126-
similarities: NDArray[np.float64] = cosine_similarity(utterance_vectors, self._description_vectors)
158+
if self._encoder_type == "bi":
159+
if self._description_vectors is None:
160+
error_text = "Description vectors are not initialized. Call fit() before predict()."
161+
raise RuntimeError(error_text)
162+
163+
if self._embedder is None:
164+
error_text = "Embedder is not initialized. Call fit() before predict()."
165+
raise RuntimeError(error_text)
166+
167+
utterance_vectors = self._embedder.embed(utterances, TaskTypeEnum.sts)
168+
similarities: NDArray[np.float64] = np.array(
169+
self._embedder.similarity(utterance_vectors, self._description_vectors), dtype=np.float64
170+
)
171+
else:
172+
if self._cross_encoder is None:
173+
error_text = "Cross encoder is not initialized. Call fit() before predict()."
174+
raise RuntimeError(error_text)
175+
176+
if self._description_texts is None:
177+
error_text = "Description texts are not initialized. Call fit() before predict()."
178+
raise RuntimeError(error_text)
179+
180+
pairs = [(utterance, description) for utterance in utterances for description in self._description_texts]
181+
182+
scores = self._cross_encoder.predict(pairs)
183+
similarities = np.array(scores, dtype=np.float64).reshape(len(utterances), len(self._description_texts))
127184

128185
if self._multilabel:
129186
probabilities = scipy.special.expit(similarities / self.temperature)
@@ -132,8 +189,11 @@ def predict(self, utterances: list[str]) -> NDArray[np.float64]:
132189
return probabilities # type: ignore[no-any-return]
133190

134191
def clear_cache(self) -> None:
135-
"""Clear cached data in memory used by the embedder."""
136-
self._embedder.clear_ram()
192+
"""Clear cached data in memory used by the embedder or cross-encoder."""
193+
if self._embedder is not None:
194+
self._embedder.clear_ram()
195+
if self._cross_encoder is not None:
196+
self._cross_encoder.clear_ram()
137197

138198
def get_train_data(self, context: Context) -> tuple[list[str], ListOfLabels, list[str]]:
139199
"""Get training data from context.

docs/optimizer_config.schema.json

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,19 @@
212212
"description": "Prompt for passage.",
213213
"title": "Passage Prompt"
214214
},
215+
"similarity_fn_name": {
216+
"anyOf": [
217+
{
218+
"type": "string"
219+
},
220+
{
221+
"type": "null"
222+
}
223+
],
224+
"default": "cosine",
225+
"description": "Name of the similarity function to use (cosine, dot, euclidean, manhattan).",
226+
"title": "Similarity Fn Name"
227+
},
215228
"use_cache": {
216229
"default": false,
217230
"description": "Whether to use embeddings caching.",
@@ -389,6 +402,7 @@
389402
"sts_prompt": null,
390403
"query_prompt": null,
391404
"passage_prompt": null,
405+
"similarity_fn_name": "cosine",
392406
"use_cache": false
393407
}
394408
},

tests/assets/configs/description.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
search_space:
1212
- module_name: description
1313
temperature: [1.0, 0.5, 0.1, 0.05]
14+
embedder_config:
15+
- model_name: sentence-transformers/all-MiniLM-L6-v2
16+
cross_encoder_config:
17+
- model_name: cross-encoder/ms-marco-MiniLM-L-6-v2
18+
encoder_type: [cross, bi]
1419
- node_type: decision
1520
target_metric: decision_accuracy
1621
search_space:

tests/callback/test_callback.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def test_pipeline_callbacks(dataset):
146146
"query_prompt": None,
147147
"sts_prompt": None,
148148
"use_cache": False,
149+
"similarity_fn_name": "cosine",
149150
"trust_remote_code": False,
150151
},
151152
"k": 1,
@@ -181,6 +182,7 @@ def test_pipeline_callbacks(dataset):
181182
"query_prompt": None,
182183
"sts_prompt": None,
183184
"use_cache": False,
185+
"similarity_fn_name": "cosine",
184186
"trust_remote_code": False,
185187
},
186188
"k": 1,
@@ -216,6 +218,7 @@ def test_pipeline_callbacks(dataset):
216218
"query_prompt": None,
217219
"sts_prompt": None,
218220
"use_cache": False,
221+
"similarity_fn_name": "cosine",
219222
"trust_remote_code": False,
220223
},
221224
},

tests/modules/scoring/test_description.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import tempfile
2+
13
import numpy as np
24
import pytest
35

@@ -45,3 +47,63 @@ def test_description_scorer(dataset, expected_prediction, multilabel):
4547
assert metadata is None
4648

4749
scorer.clear_cache()
50+
51+
52+
@pytest.mark.parametrize(
53+
("expected_prediction", "multilabel"),
54+
[
55+
([[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]], True),
56+
([[0.2, 0.3, 0.2, 0.2], [0.2, 0.3, 0.2, 0.2]], False),
57+
],
58+
)
59+
def test_description_scorer_cross_encoder(dataset, expected_prediction, multilabel):
60+
if multilabel:
61+
dataset = dataset.to_multilabel()
62+
data_handler = DataHandler(dataset)
63+
64+
scorer = DescriptionScorer(
65+
cross_encoder_config="cross-encoder/ms-marco-MiniLM-L-6-v2", encoder_type="cross", temperature=0.3
66+
)
67+
68+
scorer.fit(
69+
data_handler.train_utterances(0),
70+
data_handler.train_labels(0),
71+
data_handler.intent_descriptions,
72+
)
73+
assert scorer._description_texts is not None
74+
assert len(scorer._description_texts) == len(data_handler.intent_descriptions)
75+
assert scorer._cross_encoder is not None
76+
77+
test_utterances = [
78+
"What is the balance on my account?",
79+
"How do I reset my online banking password?",
80+
]
81+
82+
predictions = scorer.predict(test_utterances)
83+
if multilabel:
84+
assert np.sum(predictions) <= len(test_utterances) * 4
85+
else:
86+
np.testing.assert_almost_equal(np.sum(predictions), len(test_utterances))
87+
88+
assert predictions.shape == (len(test_utterances), len(data_handler.intent_descriptions))
89+
np.testing.assert_almost_equal(predictions, np.array(expected_prediction).reshape(predictions.shape), decimal=1)
90+
91+
predictions, metadata = scorer.predict_with_metadata(test_utterances)
92+
assert len(predictions) == len(test_utterances)
93+
assert metadata is None
94+
95+
scorer.clear_cache()
96+
97+
with tempfile.TemporaryDirectory() as temp_dir:
98+
scorer.dump(temp_dir)
99+
100+
new_scorer = DescriptionScorer(
101+
cross_encoder_config="cross-encoder/ms-marco-MiniLM-L-6-v2", encoder_type="cross", temperature=0.3
102+
)
103+
new_scorer.load(temp_dir)
104+
105+
loaded_predictions = new_scorer.predict(test_utterances)
106+
107+
np.testing.assert_almost_equal(predictions, loaded_predictions, decimal=5)
108+
109+
new_scorer.clear_cache()

0 commit comments

Comments (0)