Commit 53d1606

RerankScorer implementation based on KNN (#50)
1 parent 3380de8 commit 53d1606

File tree: 9 files changed, +322 −19 lines

autointent/modules/__init__.py

Lines changed: 3 additions & 2 deletions

@@ -11,7 +11,7 @@
     TunablePredictor,
 )
 from .retrieval import RetrievalModule, VectorDBModule
-from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, ScoringModule
+from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, RerankScorer, ScoringModule

 T = TypeVar("T", bound=Module)

@@ -25,7 +25,7 @@ def create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
 RETRIEVAL_MODULES_MULTILABEL = RETRIEVAL_MODULES_MULTICLASS

 SCORING_MODULES_MULTICLASS: dict[str, type[ScoringModule]] = create_modules_dict(
-    [DNNCScorer, KNNScorer, LinearScorer, DescriptionScorer],
+    [DNNCScorer, KNNScorer, LinearScorer, DescriptionScorer, RerankScorer]
 )

 SCORING_MODULES_MULTILABEL: dict[str, type[ScoringModule]] = create_modules_dict(
@@ -58,6 +58,7 @@ def create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
     "Module",
     "PredictionModule",
     "RegExp",
+    "RerankScorer",
     "RetrievalModule",
     "ScoringModule",
     "ThresholdPredictor",
autointent/modules/scoring/__init__.py

Lines changed: 10 additions & 2 deletions

@@ -1,8 +1,16 @@
 from ._base import ScoringModule
 from ._description import DescriptionScorer
 from ._dnnc import DNNCScorer
-from ._knn import KNNScorer
+from ._knn import KNNScorer, RerankScorer
 from ._linear import LinearScorer
 from ._mlknn import MLKnnScorer

-__all__ = ["DNNCScorer", "DescriptionScorer", "KNNScorer", "LinearScorer", "MLKnnScorer", "ScoringModule"]
+__all__ = [
+    "DNNCScorer",
+    "DescriptionScorer",
+    "KNNScorer",
+    "LinearScorer",
+    "MLKnnScorer",
+    "RerankScorer",
+    "ScoringModule",
+]
autointent/modules/scoring/_knn/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,3 +1,4 @@
 from .knn import KNNScorer
+from .rerank_scorer import RerankScorer

-__all__ = ["KNNScorer"]
+__all__ = ["KNNScorer", "RerankScorer"]

autointent/modules/scoring/_knn/knn.py

Lines changed: 29 additions & 14 deletions

@@ -51,6 +51,7 @@ class KNNScorer(ScoringModule):
     _vector_index: VectorIndex
     name = "knn"
     prebuilt_index: bool = False
+    max_length: int | None

     def __init__(
         self,
@@ -193,13 +194,7 @@ def dump(self, path: str) -> None:

         :param path: Path to the directory where assets will be dumped.
         """
-        self.metadata = KNNScorerDumpMetadata(
-            db_dir=self.db_dir,
-            n_classes=self.n_classes,
-            multilabel=self.multilabel,
-            batch_size=self.batch_size,
-            max_length=self.max_length,
-        )
+        self.metadata = self._store_state_to_metadata()

         dump_dir = Path(path)

@@ -208,6 +203,15 @@ def dump(self, path: str) -> None:

         self._vector_index.dump(dump_dir)

+    def _store_state_to_metadata(self) -> KNNScorerDumpMetadata:
+        return KNNScorerDumpMetadata(
+            db_dir=self.db_dir,
+            n_classes=self.n_classes,
+            multilabel=self.multilabel,
+            batch_size=self.batch_size,
+            max_length=self.max_length,
+        )
+
     def load(self, path: str) -> None:
         """
         Load the KNNScorer's metadata and vector index from disk.
@@ -219,24 +223,35 @@ def load(self, path: str) -> None:
         with (dump_dir / self.metadata_dict_name).open() as file:
             self.metadata: KNNScorerDumpMetadata = json.load(file)

-        self.n_classes = self.metadata["n_classes"]
-        self.multilabel = self.metadata["multilabel"]
+        self._restore_state_from_metadata(self.metadata)
+
+    def _restore_state_from_metadata(self, metadata: KNNScorerDumpMetadata) -> None:
+        self.n_classes = metadata["n_classes"]
+        self.multilabel = metadata["multilabel"]

         vector_index_client = VectorIndexClient(
             device=self.device,
-            db_dir=self.metadata["db_dir"],
-            embedder_batch_size=self.metadata["batch_size"],
-            embedder_max_length=self.metadata["max_length"],
+            db_dir=metadata["db_dir"],
+            embedder_batch_size=metadata["batch_size"],
+            embedder_max_length=metadata["max_length"],
         )
         self._vector_index = vector_index_client.get_index(self.embedder_name)

+    def _get_neighbours(
+        self, utterances: list[str]
+    ) -> tuple[list[list[LabelType]], list[list[float]], list[list[str]]]:
+        return self._vector_index.query(utterances, self.k)
+
+    def _count_scores(self, labels: npt.NDArray[Any], distances: npt.NDArray[Any]) -> npt.NDArray[Any]:
+        return apply_weights(labels, distances, self.weights, self.n_classes, self.multilabel)
+
     def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[str]]]:
         """
         Predict class probabilities and retrieve neighbors for the given utterances.

         :param utterances: List of query utterances.
         :return: Tuple containing class probabilities and neighbor utterances.
         """
-        labels, distances, neighbors = self._vector_index.query(utterances, self.k)
-        scores = apply_weights(np.array(labels), np.array(distances), self.weights, self.n_classes, self.multilabel)
+        labels, distances, neighbors = self._get_neighbours(utterances)
+        scores = self._count_scores(np.array(labels), np.array(distances))
         return scores, neighbors
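
This refactor of knn.py splits `dump`, `load`, and `_predict` into overridable hooks (`_store_state_to_metadata`, `_restore_state_from_metadata`, `_get_neighbours`, `_count_scores`) so a subclass can change how neighbours are gathered or scored without copying the rest. A minimal sketch of the pattern; the subclass below is hypothetical and only illustrates the override point that `RerankScorer` relies on.

```python
# Hypothetical subclass illustrating the new override points: it keeps only the
# top-m raw neighbours before the inherited weighting step, whereas RerankScorer
# instead lets a cross-encoder decide which neighbours to keep.
from autointent.custom_types import LabelType
from autointent.modules import KNNScorer


class TruncatedKNNScorer(KNNScorer):
    name = "truncated_knn"  # hypothetical module name, not part of this commit
    m: int = 3

    def _get_neighbours(
        self, utterances: list[str]
    ) -> tuple[list[list[LabelType]], list[list[float]], list[list[str]]]:
        labels, distances, neighbors = super()._get_neighbours(utterances)
        # Trim each neighbour list in place of a real re-ranking step.
        return (
            [row[: self.m] for row in labels],
            [row[: self.m] for row in distances],
            [row[: self.m] for row in neighbors],
        )
```

`_count_scores` and the dump/load metadata hooks stay inherited, which is exactly how the new `RerankScorer` below reuses them.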
autointent/modules/scoring/_knn/rerank_scorer.py (new file)

Lines changed: 211 additions & 0 deletions

"""RerankScorer class for re-ranking based on cross-encoder scoring."""

import json
from pathlib import Path
from typing import Any

import numpy as np
import numpy.typing as npt
from sentence_transformers import CrossEncoder
from torch.nn import Sigmoid
from typing_extensions import Self

from autointent.context import Context
from autointent.custom_types import WEIGHT_TYPES, LabelType

from .knn import KNNScorer, KNNScorerDumpMetadata


class RerankScorerDumpMetadata(KNNScorerDumpMetadata):
    """
    Metadata for dumping the state of a RerankScorer.

    :ivar cross_encoder_name: Name of the cross-encoder model used.
    :ivar m: Number of top-ranked neighbors to consider, or None to use k.
    :ivar rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None.
    """

    cross_encoder_name: str
    m: int | None
    rank_threshold_cutoff: int | None


class RerankScorer(KNNScorer):
    """
    Re-ranking scorer using a cross-encoder for intent classification.

    This module uses a cross-encoder to re-rank the nearest neighbors retrieved by a KNN scorer.

    :ivar name: Name of the scorer, defaults to "rerank".
    :ivar _scorer: CrossEncoder instance for re-ranking.
    """

    name = "rerank"
    _scorer: CrossEncoder

    def __init__(
        self,
        embedder_name: str,
        k: int,
        weights: WEIGHT_TYPES,
        cross_encoder_name: str,
        m: int | None = None,
        rank_threshold_cutoff: int | None = None,
        db_dir: str | None = None,
        device: str = "cpu",
        batch_size: int = 32,
        max_length: int | None = None,
    ) -> None:
        """
        Initialize the RerankScorer.

        :param embedder_name: Name of the embedder used for vectorization.
        :param k: Number of closest neighbors to consider during inference.
        :param weights: Weighting strategy:
            - "uniform" (or False): Equal weight for all neighbors.
            - "distance" (or True): Weight inversely proportional to distance.
            - "closest": Only the closest neighbor of each class is weighted.
        :param cross_encoder_name: Name of the cross-encoder model used for re-ranking.
        :param m: Number of top-ranked neighbors to consider, or None to use k.
        :param rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None.
        :param db_dir: Path to the database directory, or None to use default.
        :param device: Device to run operations on, e.g., "cpu" or "cuda".
        :param batch_size: Batch size for embedding generation, defaults to 32.
        :param max_length: Maximum sequence length for embedding, or None for default.
        """
        super().__init__(
            embedder_name=embedder_name,
            k=k,
            weights=weights,
            db_dir=db_dir,
            device=device,
            batch_size=batch_size,
            max_length=max_length,
        )

        self.cross_encoder_name = cross_encoder_name
        self.m = k if m is None else m
        self.rank_threshold_cutoff = rank_threshold_cutoff

    @classmethod
    def from_context(
        cls,
        context: Context,
        k: int,
        weights: WEIGHT_TYPES,
        cross_encoder_name: str,
        embedder_name: str | None = None,
        m: int | None = None,
        rank_threshold_cutoff: int | None = None,
    ) -> Self:
        """
        Create a RerankScorer instance from a given context.

        :param context: Context object containing optimization information and vector index client.
        :param k: Number of closest neighbors to consider during inference.
        :param weights: Weighting strategy.
        :param cross_encoder_name: Name of the cross-encoder model used for re-ranking.
        :param embedder_name: Name of the embedder used for vectorization, or None to use the best existing embedder.
        :param m: Number of top-ranked neighbors to consider, or None to use k.
        :param rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None.
        :return: An instance of RerankScorer.
        """
        if embedder_name is None:
            embedder_name = context.optimization_info.get_best_embedder()
            prebuilt_index = True
        else:
            prebuilt_index = context.vector_index_client.exists(embedder_name)

        instance = cls(
            embedder_name=embedder_name,
            k=k,
            weights=weights,
            cross_encoder_name=cross_encoder_name,
            m=m,
            rank_threshold_cutoff=rank_threshold_cutoff,
            db_dir=str(context.get_db_dir()),
            device=context.get_device(),
            batch_size=context.get_batch_size(),
            max_length=context.get_max_length(),
        )
        # TODO: needs re-thinking....
        instance.prebuilt_index = prebuilt_index
        return instance

    def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
        """
        Fit the RerankScorer with utterances and labels.

        :param utterances: List of utterances to fit the scorer.
        :param labels: List of labels corresponding to the utterances.
        """
        self._scorer = CrossEncoder(self.cross_encoder_name, device=self.device, max_length=self.max_length)  # type: ignore[arg-type]

        super().fit(utterances, labels)

    def _store_state_to_metadata(self) -> RerankScorerDumpMetadata:
        """
        Store the current state of the RerankScorer to metadata.

        :return: Metadata containing the current state of the RerankScorer.
        """
        return RerankScorerDumpMetadata(
            **super()._store_state_to_metadata(),
            m=self.m,
            cross_encoder_name=self.cross_encoder_name,
            rank_threshold_cutoff=self.rank_threshold_cutoff,
        )

    def load(self, path: str) -> None:
        """
        Load the RerankScorer from a given path.

        :param path: Path to the directory containing the dumped metadata.
        """
        dump_dir = Path(path)

        with (dump_dir / self.metadata_dict_name).open() as file:
            self.metadata: RerankScorerDumpMetadata = json.load(file)

        self._restore_state_from_metadata(self.metadata)

    def _restore_state_from_metadata(self, metadata: RerankScorerDumpMetadata) -> None:
        """
        Restore the state of the RerankScorer from metadata.

        :param metadata: Metadata containing the state of the RerankScorer.
        """
        super()._restore_state_from_metadata(metadata)

        self.m = metadata["m"] if metadata["m"] else self.k
        self.cross_encoder_name = metadata["cross_encoder_name"]
        self.rank_threshold_cutoff = metadata["rank_threshold_cutoff"]
        self._scorer = CrossEncoder(self.cross_encoder_name, device=self.device, max_length=self.max_length)  # type: ignore[arg-type]

    def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[str]]]:
        """
        Predict the scores and neighbors for given utterances.

        :param utterances: List of utterances to predict scores for.
        :return: A tuple containing the scores and neighbors.
        """
        knn_labels, knn_distances, knn_neighbors = self._get_neighbours(utterances)

        labels: list[list[LabelType]] = []
        distances: list[list[float]] = []
        neighbours: list[list[str]] = []

        for query, query_labels, query_distances, query_docs in zip(
            utterances, knn_labels, knn_distances, knn_neighbors, strict=True
        ):
            cur_ranks = self._scorer.rank(
                query, query_docs, top_k=self.m, batch_size=self.batch_size, activation_fct=Sigmoid()
            )

            for dst, src in zip(
                [labels, distances, neighbours], [query_labels, query_distances, query_docs], strict=True
            ):
                dst.append([src[rank["corpus_id"]] for rank in cur_ranks])  # type: ignore[attr-defined, index]

        scores = self._count_scores(np.array(labels), np.array(distances))
        return scores, neighbours
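
At prediction time the scorer retrieves `k` candidates per query from the vector index, lets the cross-encoder rank them with a sigmoid over its scores, keeps the top `m`, and then applies the inherited KNN weighting to the re-ordered labels and distances. A rough usage sketch follows; in the library these modules are normally created via `from_context` inside the optimization pipeline, so the direct construction, the toy data, and the call to the internal `_predict` are all illustrative assumptions rather than the documented API.

```python
# Illustrative sketch only: toy data, direct construction instead of
# from_context, and calling the internal _predict are assumptions.
from autointent.modules import RerankScorer

utterances = ["play some jazz", "turn the music up", "what's the weather", "will it rain tomorrow"]
labels = [0, 0, 1, 1]  # 0 = music, 1 = weather

scorer = RerankScorer(
    embedder_name="avsolatorio/GIST-small-Embedding-v0",        # embedder from the test configs
    k=3,
    weights="distance",
    cross_encoder_name="cross-encoder/ms-marco-MiniLM-L-6-v2",  # cross-encoder from the test configs
    m=2,  # keep the 2 best cross-encoder ranks per query
)
scorer.fit(utterances, labels)

scores, neighbours = scorer._predict(["switch on some music"])
print(scores)         # class scores for the single query
print(neighbours[0])  # its re-ranked neighbour utterances
```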

tests/assets/configs/multiclass.yaml

Lines changed: 6 additions & 0 deletions

@@ -20,6 +20,12 @@ nodes:
           - avsolatorio/GIST-small-Embedding-v0
         k: [1, 3]
         train_head: [false, true]
+      - module_type: rerank
+        k: [ 5, 10 ]
+        weights: [uniform, distance, closest]
+        m: [ 2, 3 ]
+        cross_encoder_name:
+          - cross-encoder/ms-marco-MiniLM-L-6-v2
   - node_type: prediction
     metric: prediction_accuracy
     search_space:
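
The new `rerank` entry adds a small hyperparameter grid to the scoring node's search space. As a rough illustration, and assuming the optimizer enumerates the Cartesian product of the listed values (the actual search strategy is not part of this commit), the entry expands to twelve candidate configurations:

```python
# Size of the rerank search space above, assuming a plain Cartesian product
# over the listed values (an assumption; the optimizer's strategy isn't shown here).
from itertools import product

grid = list(product(
    [5, 10],                                   # k
    ["uniform", "distance", "closest"],        # weights
    [2, 3],                                    # m
    ["cross-encoder/ms-marco-MiniLM-L-6-v2"],  # cross_encoder_name
))
print(len(grid))  # 2 * 3 * 2 * 1 = 12
```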

tests/assets/configs/multilabel.yaml

Lines changed: 6 additions & 0 deletions

@@ -16,6 +16,12 @@ nodes:
       - module_type: linear
       - module_type: mlknn
         k: [5]
+      - module_type: rerank
+        k: [ 5, 10 ]
+        weights: [ uniform, distance, closest ]
+        m: [ 2, 3 ]
+        cross_encoder_name:
+          - cross-encoder/ms-marco-MiniLM-L-6-v2
   - node_type: prediction
     metric: prediction_accuracy
     search_space:
