@@ -29,7 +29,6 @@ class RerankScorer(KNNScorer):
2929
3030 cross_encoder_config: Config of the cross-encoder model used for re-ranking
3131 m: Number of top-ranked neighbors to consider, or None to use k
32- rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None
3332 """
3433
3534 name = "rerank"
@@ -38,9 +37,9 @@ class RerankScorer(KNNScorer):
3837 def __init__ (
3938 self ,
4039 k : int ,
41- weights : WeightType ,
40+ weights : WeightType = "distance" ,
41+ use_crosencoder_scores : bool = False ,
4242 m : int | None = None ,
43- rank_threshold_cutoff : int | None = None ,
4443 cross_encoder_config : CrossEncoderConfig | str | dict [str , Any ] | None = None ,
4544 embedder_config : EmbedderConfig | str | dict [str , Any ] | None = None ,
4645 ) -> None :
@@ -53,18 +52,12 @@ def __init__(
5352 self .cross_encoder_config = CrossEncoderConfig .from_search_config (cross_encoder_config )
5453
5554 self .m = k if m is None else m
56- self .rank_threshold_cutoff = rank_threshold_cutoff
55+ self .use_crosencoder_scores = use_crosencoder_scores
5756
5857 if self .m < 0 or not isinstance (self .m , int ):
5958 msg = "`m` argument of `RerankScorer` must be a positive int"
6059 raise ValueError (msg )
6160
62- if self .rank_threshold_cutoff is not None and (
63- self .rank_threshold_cutoff < 0 or not isinstance (self .rank_threshold_cutoff , int )
64- ):
65- msg = "`rank_threshold_cutoff` argument of `RerankScorer` must be a positive int or None"
66- raise ValueError (msg )
67-
6861 @classmethod
6962 def from_context (
7063 cls ,
@@ -74,7 +67,7 @@ def from_context(
7467 m : PositiveInt | None = None ,
7568 cross_encoder_config : CrossEncoderConfig | str | None = None ,
7669 embedder_config : EmbedderConfig | str | None = None ,
77- rank_threshold_cutoff : int | None = None ,
70+ use_crosencoder_scores : bool = False ,
7871 ) -> "RerankScorer" :
7972 """Create a RerankScorer instance from a given context.
8073
@@ -86,7 +79,7 @@ def from_context(
8679 embedder_config: Config of the embedder used for vectorization,
8780 or None to use the best existing embedder
8881 m: Number of top-ranked neighbors to consider, or None to use k
89- rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None
82+ use_crosencoder_scores: Use cross-encoder scores for the output probability vector computation
9083 """
9184 if embedder_config is None :
9285 embedder_config = context .resolve_embedder ()
@@ -98,7 +91,7 @@ def from_context(
9891 k = k ,
9992 weights = weights ,
10093 m = m ,
101- rank_threshold_cutoff = rank_threshold_cutoff ,
94+ use_crosencoder_scores = use_crosencoder_scores ,
10295 embedder_config = embedder_config ,
10396 cross_encoder_config = cross_encoder_config ,
10497 )
@@ -113,9 +106,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
113106 if hasattr (self , "_scorer" ):
114107 self .clear_cache ()
115108
116- self ._scorer = Ranker (
117- self .cross_encoder_config ,
118- )
109+ self ._scorer = Ranker (self .cross_encoder_config , output_range = "tanh" )
119110 self ._scorer .fit (utterances , labels )
120111
121112 super ().fit (utterances , labels , clear_cache = False )
@@ -147,10 +138,14 @@ def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[s
147138 ):
148139 cur_ranks = self ._scorer .rank (query , query_docs , top_k = self .m )
149140
150- for dst , src in zip (
151- [labels , distances , neighbours ], [query_labels , query_distances , query_docs ], strict = True
152- ):
141+ for dst , src in zip ([labels , neighbours ], [query_labels , query_docs ], strict = True ):
153142 dst .append ([src [rank ["corpus_id" ]] for rank in cur_ranks ]) # type: ignore[attr-defined]
154143
144+ if self .use_crosencoder_scores :
145+ distances .append ([rank ["score" ] for rank in cur_ranks ])
146+ else :
147+ distances .append ([query_distances [rank ["corpus_id" ]] for rank in cur_ranks ])
148+
155149 scores = self ._count_scores (np .array (labels ), np .array (distances ))
150+
156151 return scores , neighbours
0 commit comments