Skip to content

Commit 812940c

Browse files
Dmitryv-2024voorhs
andauthored
Rerank scorer: опция для выбора источника для расчета вектора вероятностей (#115)
* Enable rerank scorer to use crossencoder scores for the probability vector * add cross encoder scores range options * upd test --------- Co-authored-by: voorhs <[email protected]>
1 parent 54098c0 commit 812940c

File tree

6 files changed

+27
-24
lines changed

6 files changed

+27
-24
lines changed

autointent/_ranker.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import logging
1111
from pathlib import Path
1212
from random import shuffle
13-
from typing import Any, TypedDict
13+
from typing import Any, Literal, TypedDict
1414

1515
import joblib
1616
import numpy as np
@@ -101,12 +101,14 @@ def __init__(
101101
self,
102102
cross_encoder_config: CrossEncoderConfig | str | dict[str, Any],
103103
classifier_head: LogisticRegressionCV | None = None,
104+
output_range: Literal["sigmoid", "tanh"] = "sigmoid",
104105
) -> None:
105106
"""Initialize the Ranker.
106107
107108
Args:
108109
cross_encoder_config: Configuration for the cross-encoder model
109110
classifier_head: Optional pre-trained classifier head
111+
output_range: Range of the output probabilities ([0, 1] for sigmoid, [-1, 1] for tanh)
110112
"""
111113
self.config = CrossEncoderConfig.from_search_config(cross_encoder_config)
112114
self.cross_encoder = st.CrossEncoder(
@@ -117,6 +119,7 @@ def __init__(
117119
)
118120
self._train_head = False
119121
self._clf = classifier_head
122+
self.output_range = output_range
120123

121124
if classifier_head is not None or self.config.train_head:
122125
self._train_head = True
@@ -148,7 +151,7 @@ def _get_features_or_predictions(self, pairs: list[tuple[str, str]]) -> npt.NDAr
148151
self.cross_encoder.predict(
149152
pairs,
150153
batch_size=self.config.batch_size,
151-
activation_fct=nn.Sigmoid(),
154+
activation_fct=nn.Sigmoid() if self.output_range == "sigmoid" else nn.Tanh(),
152155
)
153156
)
154157

@@ -210,7 +213,10 @@ def predict(self, pairs: list[tuple[str, str]]) -> npt.NDArray[Any]:
210213
features = self._get_features_or_predictions(pairs)
211214

212215
if self._clf is not None:
213-
return np.array(self._clf.predict_proba(features)[:, 1])
216+
probs = np.array(self._clf.predict_proba(features)[:, 1])
217+
if self.output_range == "tanh":
218+
probs = probs * 2 - 1
219+
return probs
214220
return features
215221

216222
def rank(

autointent/modules/scoring/_dnnc/dnnc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
119119
self._vector_index = VectorIndex(self.embedder_config)
120120
self._vector_index.add(utterances, labels)
121121

122-
self._cross_encoder = Ranker(self.cross_encoder_config)
122+
self._cross_encoder = Ranker(self.cross_encoder_config, output_range="sigmoid")
123123
self._cross_encoder.fit(utterances, labels)
124124

125125
def predict(self, utterances: list[str]) -> npt.NDArray[Any]:

autointent/modules/scoring/_knn/rerank_scorer.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ class RerankScorer(KNNScorer):
2929
3030
cross_encoder_config: Config of the cross-encoder model used for re-ranking
3131
m: Number of top-ranked neighbors to consider, or None to use k
32-
rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None
3332
"""
3433

3534
name = "rerank"
@@ -38,9 +37,9 @@ class RerankScorer(KNNScorer):
3837
def __init__(
3938
self,
4039
k: int,
41-
weights: WeightType,
40+
weights: WeightType = "distance",
41+
use_crosencoder_scores: bool = False,
4242
m: int | None = None,
43-
rank_threshold_cutoff: int | None = None,
4443
cross_encoder_config: CrossEncoderConfig | str | dict[str, Any] | None = None,
4544
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
4645
) -> None:
@@ -53,18 +52,12 @@ def __init__(
5352
self.cross_encoder_config = CrossEncoderConfig.from_search_config(cross_encoder_config)
5453

5554
self.m = k if m is None else m
56-
self.rank_threshold_cutoff = rank_threshold_cutoff
55+
self.use_crosencoder_scores = use_crosencoder_scores
5756

5857
if self.m < 0 or not isinstance(self.m, int):
5958
msg = "`m` argument of `RerankScorer` must be a positive int"
6059
raise ValueError(msg)
6160

62-
if self.rank_threshold_cutoff is not None and (
63-
self.rank_threshold_cutoff < 0 or not isinstance(self.rank_threshold_cutoff, int)
64-
):
65-
msg = "`rank_threshold_cutoff` argument of `RerankScorer` must be a positive int or None"
66-
raise ValueError(msg)
67-
6861
@classmethod
6962
def from_context(
7063
cls,
@@ -74,7 +67,7 @@ def from_context(
7467
m: PositiveInt | None = None,
7568
cross_encoder_config: CrossEncoderConfig | str | None = None,
7669
embedder_config: EmbedderConfig | str | None = None,
77-
rank_threshold_cutoff: int | None = None,
70+
use_crosencoder_scores: bool = False,
7871
) -> "RerankScorer":
7972
"""Create a RerankScorer instance from a given context.
8073
@@ -86,7 +79,7 @@ def from_context(
8679
embedder_config: Config of the embedder used for vectorization,
8780
or None to use the best existing embedder
8881
m: Number of top-ranked neighbors to consider, or None to use k
89-
rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None
82+
use_crosencoder_scores: use crosencoder scores for the output probability vector computation
9083
"""
9184
if embedder_config is None:
9285
embedder_config = context.resolve_embedder()
@@ -98,7 +91,7 @@ def from_context(
9891
k=k,
9992
weights=weights,
10093
m=m,
101-
rank_threshold_cutoff=rank_threshold_cutoff,
94+
use_crosencoder_scores=use_crosencoder_scores,
10295
embedder_config=embedder_config,
10396
cross_encoder_config=cross_encoder_config,
10497
)
@@ -113,9 +106,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
113106
if hasattr(self, "_scorer"):
114107
self.clear_cache()
115108

116-
self._scorer = Ranker(
117-
self.cross_encoder_config,
118-
)
109+
self._scorer = Ranker(self.cross_encoder_config, output_range="tanh")
119110
self._scorer.fit(utterances, labels)
120111

121112
super().fit(utterances, labels, clear_cache=False)
@@ -147,10 +138,14 @@ def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[s
147138
):
148139
cur_ranks = self._scorer.rank(query, query_docs, top_k=self.m)
149140

150-
for dst, src in zip(
151-
[labels, distances, neighbours], [query_labels, query_distances, query_docs], strict=True
152-
):
141+
for dst, src in zip([labels, neighbours], [query_labels, query_docs], strict=True):
153142
dst.append([src[rank["corpus_id"]] for rank in cur_ranks]) # type: ignore[attr-defined]
154143

144+
if self.use_crosencoder_scores:
145+
distances.append([rank["score"] for rank in cur_ranks])
146+
else:
147+
distances.append([query_distances[rank["corpus_id"]] for rank in cur_ranks])
148+
155149
scores = self._count_scores(np.array(labels), np.array(distances))
150+
156151
return scores, neighbours

tests/assets/configs/multiclass.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
k: [ 5, 10 ]
2424
weights: [uniform, distance, closest]
2525
m: [ 2, 3 ]
26+
use_crosencoder_scores: [true, false]
2627
cross_encoder_config:
2728
- cross-encoder/ms-marco-MiniLM-L-6-v2
2829
- module_name: sklearn

tests/assets/configs/multilabel.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
- module_name: rerank
1919
k: [ 5, 10 ]
2020
weights: [ uniform, distance, closest ]
21+
use_crosencoder_scores: [true, false]
2122
m: [ 2, 3 ]
2223
cross_encoder_config:
2324
- model_name: cross-encoder/ms-marco-MiniLM-L-6-v2

tests/configs/test_scoring.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def valid_scoring_config():
4444
"embedder_config": ["sergeyzh/rubert-tiny-turbo"],
4545
"k": [5],
4646
"weights": ["distance"],
47-
"rank_threshold_cutoff": [None, 3],
47+
"use_crosencoder_scores": [True, False],
4848
},
4949
{
5050
"module_name": "sklearn",

0 commit comments

Comments
 (0)