-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathdnnc.py
More file actions
246 lines (194 loc) · 9.49 KB
/
dnnc.py
File metadata and controls
246 lines (194 loc) · 9.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
"""DNNCScorer class for scoring utterances using deep neural network classifiers (DNNC)."""
import itertools as it
import logging
from typing import Any
import numpy as np
import numpy.typing as npt
from pydantic import PositiveInt
from autointent import Context, Ranker, VectorIndex
from autointent.configs import CrossEncoderConfig, EmbedderConfig
from autointent.custom_types import ListOfLabels
from autointent.modules.abc import BaseScorer
logger = logging.getLogger(__name__)
class DNNCScorer(BaseScorer):
    r"""
    Scoring module for intent classification using a discriminative nearest neighbor classification (DNNC).

    For every query utterance the module retrieves the ``k`` nearest training utterances from a
    :class:`VectorIndex`, scores each (query, neighbor) pair with a cross-encoder (:class:`Ranker`),
    and puts the best pair's score into the column of that neighbor's class. Training of the
    cross-encoder, if any, is delegated to ``Ranker.fit``.

    .. code-block:: bibtex

        @misc{zhang2020discriminativenearestneighborfewshot,
          title={Discriminative Nearest Neighbor Few-Shot Intent Detection by Transferring Natural Language Inference},
          author={Jian-Guo Zhang and Kazuma Hashimoto and Wenhao Liu and Chien-Sheng Wu and Yao Wan and
          Philip S. Yu and Richard Socher and Caiming Xiong},
          year={2020},
          eprint={2010.13009},
          archivePrefix={arXiv},
          primaryClass={cs.CL},
          url={https://arxiv.org/abs/2010.13009},
        }

    :ivar name: Name of the scorer, ``"dnnc"``.
    :ivar k: Number of nearest neighbors retrieved for each query.
    :ivar cross_encoder_config: Config of the cross-encoder used by the ``Ranker``.
    :ivar embedder_config: Config of the embedder used by the ``VectorIndex``.

    Examples
    --------
    .. testcode::

        from autointent.modules.scoring import DNNCScorer
        utterances = ["what is your name?", "how are you?"]
        labels = [0, 1]
        scorer = DNNCScorer(
            cross_encoder_config="cross-encoder/ms-marco-MiniLM-L-6-v2",
            embedder_config="sergeyzh/rubert-tiny-turbo",
            k=5,
        )
        scorer.fit(utterances, labels)
        test_utterances = ["Hello!", "What's up?"]
        scores = scorer.predict(test_utterances)
        print(scores)  # Outputs similarity scores for the utterances

    .. testoutput::

        [[0.00013581 0.        ]
         [0.00030066 0.        ]]
    """

    name = "dnnc"
    _n_classes: int
    _vector_index: VectorIndex
    _cross_encoder: Ranker
    supports_multilabel = False
    supports_multiclass = True

    def __init__(
        self,
        k: PositiveInt,
        cross_encoder_config: CrossEncoderConfig | str | dict[str, Any] | None = None,
        embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
    ) -> None:
        """
        Initialize the DNNCScorer.

        :param k: Number of nearest neighbors to retrieve; must be a positive int.
        :param cross_encoder_config: Config of the cross-encoder model.
        :param embedder_config: Config of the embedder model.
        :raises ValueError: If ``k`` is not a positive int.
        """
        self.cross_encoder_config = CrossEncoderConfig.from_search_config(cross_encoder_config)
        self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
        self.k = k
        # Type check first so that a non-numeric ``k`` raises the ValueError below instead of a
        # TypeError from the comparison. ``k == 0`` is rejected as well: scoring needs at least
        # one neighbor per query.
        if not isinstance(self.k, int) or self.k <= 0:
            msg = "`k` argument of `DNNCScorer` must be a positive int"
            raise ValueError(msg)

    @classmethod
    def from_context(
        cls,
        context: Context,
        k: PositiveInt,
        cross_encoder_config: CrossEncoderConfig | str | dict[str, Any] | None = None,
        embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
    ) -> "DNNCScorer":
        """
        Create a DNNCScorer instance using a Context object.

        :param context: Context containing configurations and utilities.
        :param k: Number of nearest neighbors to retrieve.
        :param cross_encoder_config: Config of the cross-encoder model, or None to use
            ``context.resolve_ranker()``.
        :param embedder_config: Config of the embedder model, or None to use
            ``context.resolve_embedder()``.
        :return: Initialized DNNCScorer instance.
        """
        if embedder_config is None:
            embedder_config = context.resolve_embedder()
        if cross_encoder_config is None:
            cross_encoder_config = context.resolve_ranker()
        return cls(
            k=k,
            embedder_config=embedder_config,
            cross_encoder_config=cross_encoder_config,
        )

    def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
        """
        Fit the scorer: build the vector index over the training utterances and fit the cross-encoder.

        :param utterances: List of training utterances.
        :param labels: List of labels corresponding to the utterances.

        NOTE(review): label validation is done by ``_validate_task`` (defined in the base class);
        whatever it raises for unsupported label formats propagates from here. This scorer
        declares ``supports_multilabel = False``.
        """
        # Refitting: release memory held by the previously built index before replacing it.
        if hasattr(self, "_vector_index"):
            self.clear_cache()
        self._validate_task(labels)
        self._vector_index = VectorIndex(self.embedder_config)
        self._vector_index.add(utterances, labels)
        self._cross_encoder = Ranker(self.cross_encoder_config)
        self._cross_encoder.fit(utterances, labels)

    def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
        """
        Predict class scores for the given utterances.

        :param utterances: List of utterances to score.
        :return: Array of predicted scores, shape (n_utterances, n_classes).
        """
        return self._predict(utterances)[0]

    def predict_with_metadata(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[dict[str, Any]] | None]:
        """
        Predict class scores along with metadata for the given utterances.

        :param utterances: List of utterances to score.
        :return: Tuple of scores and one metadata dict per utterance with keys ``"neighbors"``
            (retrieved neighbor utterances) and ``"scores"`` (their cross-encoder scores).
        """
        scores, neighbors, neighbors_scores = self._predict(utterances)
        metadata = [
            {"neighbors": utterance_neighbors, "scores": utterance_neighbors_scores}
            for utterance_neighbors, utterance_neighbors_scores in zip(neighbors, neighbors_scores, strict=True)
        ]
        return scores, metadata

    def _get_cross_encoder_scores(self, utterances: list[str], candidates: list[list[str]]) -> list[list[float]]:
        """
        Compute cross-encoder scores for utterances against their candidate neighbors.

        :param utterances: List of query utterances.
        :param candidates: List of candidate utterances for each query; every query must have
            the same number of candidates.
        :return: For each query, the cross-encoder scores of its (query, candidate) pairs, in the
            same order as ``candidates``. Empty list for an empty batch.
        :raises ValueError: If the number of utterances and candidate lists do not match, or if
            the number of candidates varies between queries.
        """
        if len(utterances) != len(candidates):
            msg = "Number of utterances doesn't match number of retrieved candidates"
            logger.error(msg)
            raise ValueError(msg)
        if not utterances:
            return []

        n_candidates = len(candidates[0])
        # Check each row, not just the total: rows of lengths (2, 4, 3) sum to the same total as
        # (3, 3, 3) and would otherwise be silently mis-split below.
        if any(len(docs) != n_candidates for docs in candidates):
            msg = "Number of candidates for each query utterance cannot vary"
            logger.error(msg)
            raise ValueError(msg)

        text_pairs = [[(query, cand) for cand in docs] for query, docs in zip(utterances, candidates, strict=True)]
        flattened_text_pairs = list(it.chain.from_iterable(text_pairs))
        flattened_cross_encoder_scores = np.asarray(self._cross_encoder.predict(flattened_text_pairs))
        # Split by the actual number of candidates per query, not by ``self.k``. The two agree
        # when the index returns exactly k neighbors. Splitting by k would mix rows of different
        # queries if fewer were returned. TODO confirm VectorIndex.query behavior for small indexes.
        return flattened_cross_encoder_scores.reshape(len(utterances), n_candidates).tolist()

    def _build_result(self, scores: list[list[float]], labels: list[ListOfLabels]) -> npt.NDArray[Any]:
        """
        Build a result matrix with scores assigned to the best neighbor's class.

        :param scores: For each query utterance, cross-encoder scores of its k closest utterances.
        :param labels: Corresponding intent labels.
        :return: (n_queries, n_classes) matrix with zeros everywhere except the class of the best
            neighbor utterance.
        """
        return build_result(np.array(scores), np.array(labels), self._n_classes)

    def clear_cache(self) -> None:
        """Clear cached data in memory used by the vector index."""
        self._vector_index.clear_ram()

    def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[str]], list[list[float]]]:
        """
        Predict class scores for the given utterances using the vector index and cross-encoder.

        :param utterances: List of query utterances.
        :return: Tuple containing class scores, neighbor utterances, and neighbor scores.
        """
        # The retrieval distances are ignored; ranking is done by the cross-encoder instead.
        labels, _, neighbors = self._vector_index.query(
            utterances,
            self.k,
        )
        cross_encoder_scores = self._get_cross_encoder_scores(utterances, neighbors)
        return self._build_result(cross_encoder_scores, labels), neighbors, cross_encoder_scores
def build_result(scores: npt.ArrayLike, labels: npt.ArrayLike, n_classes: int) -> npt.NDArray[Any]:
    """
    Build a result matrix with scores assigned to the best neighbor's class.

    For each query row, the neighbor with the highest score is selected (first one on ties, as
    with ``np.argmax``). Its score is written into the column given by that neighbor's label, and
    every other entry stays 0.

    :param scores: Cross-encoder scores for each query's neighbors, shape (n_queries, n_neighbors).
    :param labels: Integer class label of each neighbor, same shape as ``scores``.
    :param n_classes: Total number of classes.
    :return: Float matrix of shape (n_queries, n_classes) holding the best score in the best class.
    """
    scores = np.asarray(scores)
    labels = np.asarray(labels)
    res = np.zeros((len(scores), n_classes))
    if len(scores) == 0:
        # Empty batch. Return early because ``np.array([])`` is 1-D and argmax(axis=1) would fail.
        return res
    rows = np.arange(len(scores))
    best_neighbors = np.argmax(scores, axis=1)
    best_classes = labels[rows, best_neighbors]
    res[rows, best_classes] = scores[rows, best_neighbors]
    return res