Skip to content

Commit e300080

Browse files
committed
Enable rerank scorer to use crossencoder scores for the probability vector
1 parent a05b530 commit e300080

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

autointent/modules/scoring/_knn/rerank_scorer.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __init__( # noqa: PLR0913
3030
embedder_name: str,
3131
k: int,
3232
weights: WEIGHT_TYPES,
33+
use_crosencoder_scores: bool = False,
3334
m: int | None = None,
3435
rank_threshold_cutoff: int | None = None,
3536
embedder_device: str = "cpu",
@@ -50,6 +51,7 @@ def __init__( # noqa: PLR0913
5051
- "uniform" (or False): Equal weight for all neighbors.
5152
- "distance" (or True): Weight inversely proportional to distance.
5253
- "closest": Only the closest neighbor of each class is weighted.
54+
:param use_crosencoder_scores: use crosencoder scores for the output probability vector computation
5355
:param cross_encoder_name: Name of the cross-encoder model used for re-ranking.
5456
:param m: Number of top-ranked neighbors to consider, or None to use k.
5557
:param rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None.
@@ -75,6 +77,7 @@ def __init__( # noqa: PLR0913
7577

7678
self.m = k if m is None else m
7779
self.rank_threshold_cutoff = rank_threshold_cutoff
80+
self.use_crosencoder_scores = use_crosencoder_scores
7881

7982
@classmethod
8083
def from_context(
@@ -87,6 +90,7 @@ def from_context(
8790
embedder_name: str | None = None,
8891
m: int | None = None,
8992
rank_threshold_cutoff: int | None = None,
93+
use_crosencoder_scores: bool = False,
9094
) -> "RerankScorer":
9195
"""
9296
Create a RerankScorer instance from a given context.
@@ -98,6 +102,7 @@ def from_context(
98102
:param embedder_name: Name of the embedder used for vectorization, or None to use the best existing embedder.
99103
:param m: Number of top-ranked neighbors to consider, or None to use k.
100104
:param rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None.
105+
:param use_crosencoder_scores: use crosencoder scores for the output probability vector computation
101106
:return: An instance of RerankScorer.
102107
"""
103108
if embedder_name is None:
@@ -107,6 +112,7 @@ def from_context(
107112
k=k,
108113
weights=weights,
109114
m=m,
115+
use_crosencoder_scores=use_crosencoder_scores,
110116
rank_threshold_cutoff=rank_threshold_cutoff,
111117
train_head=train_head,
112118
embedder_name=embedder_name,
@@ -156,10 +162,14 @@ def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[s
156162
):
157163
cur_ranks = self._scorer.rank(query, query_docs, top_k=self.m)
158164

159-
for dst, src in zip(
160-
[labels, distances, neighbours], [query_labels, query_distances, query_docs], strict=True
161-
):
165+
for dst, src in zip([labels, neighbours], [query_labels, query_docs], strict=True):
162166
dst.append([src[rank["corpus_id"]] for rank in cur_ranks]) # type: ignore[attr-defined]
163167

168+
if self.use_crosencoder_scores:
169+
distances.append([rank["score"] for rank in cur_ranks])
170+
else:
171+
distances.append([query_distances[rank["corpus_id"]] for rank in cur_ranks])
172+
164173
scores = self._count_scores(np.array(labels), np.array(distances))
174+
165175
return scores, neighbours

tests/assets/configs/multiclass.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
k: [ 5, 10 ]
3030
weights: [uniform, distance, closest]
3131
m: [ 2, 3 ]
32+
use_crosencoder_scores: [true, false]
3233
cross_encoder_name:
3334
- cross-encoder/ms-marco-MiniLM-L-6-v2
3435
- node_type: decision

tests/assets/configs/multilabel.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
- module_name: rerank
2525
k: [ 5, 10 ]
2626
weights: [ uniform, distance, closest ]
27+
use_crosencoder_scores: [true, false]
2728
m: [ 2, 3 ]
2829
cross_encoder_name:
2930
- cross-encoder/ms-marco-MiniLM-L-6-v2

0 commit comments

Comments
 (0)