@@ -30,6 +30,7 @@ def __init__( # noqa: PLR0913
3030 embedder_name : str ,
3131 k : int ,
3232 weights : WEIGHT_TYPES ,
33+ use_crosencoder_scores : bool = False ,
3334 m : int | None = None ,
3435 rank_threshold_cutoff : int | None = None ,
3536 embedder_device : str = "cpu" ,
@@ -50,6 +51,7 @@ def __init__( # noqa: PLR0913
5051 - "uniform" (or False): Equal weight for all neighbors.
5152 - "distance" (or True): Weight inversely proportional to distance.
5253 - "closest": Only the closest neighbor of each class is weighted.
54+ :param use_crosencoder_scores: Whether to use cross-encoder scores (instead of the neighbors' distances) when computing the output probability vector.
5355 :param cross_encoder_name: Name of the cross-encoder model used for re-ranking.
5456 :param m: Number of top-ranked neighbors to consider, or None to use k.
5557 :param rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None.
@@ -75,6 +77,7 @@ def __init__( # noqa: PLR0913
7577
7678 self .m = k if m is None else m
7779 self .rank_threshold_cutoff = rank_threshold_cutoff
80+ self .use_crosencoder_scores = use_crosencoder_scores
7881
7982 @classmethod
8083 def from_context (
@@ -87,6 +90,7 @@ def from_context(
8790 embedder_name : str | None = None ,
8891 m : int | None = None ,
8992 rank_threshold_cutoff : int | None = None ,
93+ use_crosencoder_scores : bool = False ,
9094 ) -> "RerankScorer" :
9195 """
9296 Create a RerankScorer instance from a given context.
@@ -98,6 +102,7 @@ def from_context(
98102 :param embedder_name: Name of the embedder used for vectorization, or None to use the best existing embedder.
99103 :param m: Number of top-ranked neighbors to consider, or None to use k.
100104 :param rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None.
105+ :param use_crosencoder_scores: Whether to use cross-encoder scores (instead of the neighbors' distances) when computing the output probability vector.
101106 :return: An instance of RerankScorer.
102107 """
103108 if embedder_name is None :
@@ -107,6 +112,7 @@ def from_context(
107112 k = k ,
108113 weights = weights ,
109114 m = m ,
115+ use_crosencoder_scores = use_crosencoder_scores ,
110116 rank_threshold_cutoff = rank_threshold_cutoff ,
111117 train_head = train_head ,
112118 embedder_name = embedder_name ,
@@ -156,10 +162,14 @@ def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[s
156162 ):
157163 cur_ranks = self ._scorer .rank (query , query_docs , top_k = self .m )
158164
159- for dst , src in zip (
160- [labels , distances , neighbours ], [query_labels , query_distances , query_docs ], strict = True
161- ):
165+ for dst , src in zip ([labels , neighbours ], [query_labels , query_docs ], strict = True ):
162166 dst .append ([src [rank ["corpus_id" ]] for rank in cur_ranks ]) # type: ignore[attr-defined]
163167
168+ if self .use_crosencoder_scores :
169+ distances .append ([rank ["score" ] for rank in cur_ranks ])
170+ else :
171+ distances .append ([query_distances [rank ["corpus_id" ]] for rank in cur_ranks ])
172+
164173 scores = self ._count_scores (np .array (labels ), np .array (distances ))
174+
165175 return scores , neighbours
0 commit comments