
Commit 2bf20ec

Dmitryv-2024 and voorhs authored
Refactored CrossEncoder into our own wrapper class to support head training (#88)
* Refactored CrossEncoder into our own wrapper class to support head training
* Fix typo in comment
* fixing tests
* Fixing mypy errors
* Fixing doc build
* Still fixing doc build
* Keep fixing doc build
* minor bug fix
* mypy was updated `(-_-)`
* change type annotation of `pairs` argument
* `_logits_list` -> `_activations_list`
* `get_features` -> `_get_features_or_predictions`

---------

Co-authored-by: Алексеев Илья <[email protected]>
Co-authored-by: voorhs <[email protected]>
1 parent 8a61a6c commit 2bf20ec
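
For context, here is a minimal usage sketch of the refactored wrapper, assembled from the docstring examples in the diff below. The model name is a placeholder taken from those examples, and the exact call sequence is an assumption based on the new API surface, not part of the commit itself:

    from autointent._transformers import NLITransformer

    # The wrapper now constructs the CrossEncoder itself from a model name string.
    # "cross-encoder-model" is a placeholder from the docstring example below.
    scorer = NLITransformer("cross-encoder-model", device="cpu", train_classifier=True)

    # Train the LogisticRegressionCV head on activations captured from the classifier input.
    utterances = ["What is your name?", "How old are you?"]
    labels = [1, 0]
    scorer.fit(utterances, labels)

    # Score pairs, or rank candidate documents against a query.
    probs = scorer.predict([("What is your name?", "How old are you?")])
    ranked = scorer.rank("What is your name?", utterances, top_k=1)

    # Persist and restore; load() now passes the saved head through the constructor.
    scorer.save("outputs/")
    restored = NLITransformer.load("outputs/")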

File tree

8 files changed: +219 -90 lines changed

Makefile

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ lint:
 
 .PHONY: sync
 sync:
-	poetry install --sync --with dev,test,lint,typing,docs
+	poetry sync
 
 .PHONY: docs
 docs:
autointent/_transformers/__init__.py (new file)

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+from ._nli_transformer import NLITransformer
+
+__all__ = ["NLITransformer"]

autointent/modules/scoring/_dnnc/head_training.py renamed to autointent/_transformers/_nli_transformer.py

Lines changed: 98 additions & 52 deletions
@@ -1,17 +1,21 @@
-"""CrossEncoderWithLogreg class for cross-encoder-based binary classification with logistic regression."""
+"""NLITransformer class for cross-encoder-based estimation of meaning closeness.
+
+Can be used to rank retrieved sentences by meaning closeness to a provided utterance.
+"""
 
 import itertools as it
 import logging
 from pathlib import Path
 from random import shuffle
-from typing import Any, TypeVar
+from typing import Any
 
 import joblib
 import numpy as np
 import numpy.typing as npt
 import torch
 from sentence_transformers import CrossEncoder
 from sklearn.linear_model import LogisticRegressionCV
+from torch import nn
 
 from autointent.custom_types import LabelType
 
@@ -54,15 +58,13 @@ def construct_samples(
     return pairs, labels
 
 
-CrossEncoderType = TypeVar("CrossEncoderType", bound="CrossEncoderWithLogreg")
-
-
-class CrossEncoderWithLogreg:
+class NLITransformer:
     r"""
-    Cross-encoder with logistic regression for binary classification.
+    Cross-encoder for NLI.
 
-    This class uses a SentenceTransformers CrossEncoder model to extract features
-    and LogisticRegressionCV for classification.
+    At its heart, this class uses a SentenceTransformers CrossEncoder model to extract features.
+    It then uses either the model's own classifier or our custom trained LogisticRegressionCV
+    (a custom classifier layer in the future) to rank documents by similarity score to the query.
 
     :ivar cross_encoder: The CrossEncoder model used to extract features.
     :ivar batch_size: Batch size for processing text pairs.
@@ -72,10 +74,8 @@ class CrossEncoderWithLogreg:
     Examples
     --------
     Creating and fitting the CrossEncoderWithLogreg:
-    >>> from autointent.modules import CrossEncoderWithLogreg
-    >>> from sentence_transformers import CrossEncoder
-    >>> model = CrossEncoder("cross-encoder-model")
-    >>> scorer = CrossEncoderWithLogreg(model)
+    >>> from autointent._transformers import NLITransformer
+    >>> scorer = NLITransformer("cross-encoder-model")
     >>> utterances = ["What is your name?", "How old are you?"]
     >>> labels = [1, 0]
     >>> scorer.fit(utterances, labels)
@@ -87,43 +87,64 @@ class CrossEncoderWithLogreg:
 
     Saving and loading the model:
     >>> scorer.save("outputs/")
-    >>> loaded_scorer = CrossEncoderWithLogreg.load("outputs/")
+    >>> loaded_scorer = NLITransformer.load("outputs/")
     """
 
-    def __init__(self, model: CrossEncoder, batch_size: int = 326) -> None:
+    def __init__(
+        self,
+        model: str,
+        device: str = "cpu",
+        train_classifier: bool = False,
+        batch_size: int = 326,
+        max_length: int | None = None,
+        classifier_head: LogisticRegressionCV | None = None,
+    ) -> None:
         """
-        Initialize the CrossEncoderWithLogreg.
+        Initialize the NLITransformer.
 
-        :param model: The CrossEncoder model to use.
+        :param model: The CrossEncoder model name to use.
+        :param device: Device to run operations on, e.g., "cpu" or "cuda".
+        :param train_classifier: Whether to train a custom classifier, defaults to False.
         :param batch_size: Batch size for processing text pairs, defaults to 326.
+        :param max_length: Max length for input sequences for the cross-encoder, optional.
+        :param classifier_head: Optional classifier head (mainly used by the restore procedure).
         """
-        self.cross_encoder = model
+        self.cross_encoder = CrossEncoder(model, trust_remote_code=True, device=device, max_length=max_length)  # type: ignore[arg-type]
+        self.train_classifier = False
         self.batch_size = batch_size
+        self.max_length = max_length
+        self._clf = classifier_head
+
+        if classifier_head is not None or train_classifier:
+            self.train_classifier = True
+            self._activations_list: list[npt.NDArray[Any]] = []
+            self._hook_handler = self.cross_encoder.model.classifier.register_forward_hook(self._classifier_hook)
+
+    def _classifier_hook(self, _module, input_tensor, _output_tensor) -> None:  # type: ignore[no-untyped-def] # noqa: ANN001
+        self._activations_list.append(input_tensor[0].cpu().numpy())
 
     @torch.no_grad()
-    def get_features(self, pairs: list[list[str]]) -> npt.NDArray[Any]:
+    def _get_features_or_predictions(self, pairs: list[tuple[str, str]]) -> npt.NDArray[Any]:
         """
-        Extract features from text pairs using the CrossEncoder model.
+        Extract features or get predictions using the CrossEncoder model.
+
+        If :py:attr:`~train_classifier` is ``True``, return raw activations from the
+        cross-encoder transformer. Otherwise, get predictions from the cross-encoder head.
 
         :param pairs: List of text pairs.
        :return: Numpy array of extracted features.
         """
-        logits_list: list[npt.NDArray[Any]] = []
-
-        def hook_function(module, input_tensor, output_tensor) -> None:  # type: ignore[no-untyped-def] # noqa: ARG001, ANN001
-            logits_list.append(input_tensor[0].cpu().numpy())
+        if not self.train_classifier:
+            return np.array(self.cross_encoder.predict(pairs, batch_size=self.batch_size, activation_fct=nn.Sigmoid()))
 
-        handler = self.cross_encoder.model.classifier.register_forward_hook(hook_function)
+        # put the data through; features are captured by the forward hook
+        self.cross_encoder.predict(pairs, batch_size=self.batch_size)
 
-        for i in range(0, len(pairs), self.batch_size):
-            batch = pairs[i : i + self.batch_size]
-            self.cross_encoder.predict(batch)
+        res = np.concatenate(self._activations_list, axis=0)
+        self._activations_list.clear()
+        return res  # type: ignore[no-any-return]
 
-        handler.remove()
-
-        return np.concatenate(logits_list, axis=0)
-
-    def _fit(self, pairs: list[list[str]], labels: list[LabelType]) -> None:
+    def _fit(self, pairs: list[tuple[str, str]], labels: list[LabelType]) -> None:
         """
         Train the logistic regression model on cross-encoder features.
 
@@ -137,8 +158,10 @@ def _fit(self, pairs: list[tuple[str, str]], labels: list[LabelType]) -> None:
             logger.error(msg)
             raise ValueError(msg)
 
-        features = self.get_features(pairs)
+        features = self._get_features_or_predictions(pairs)
 
+        # TODO: LogisticRegressionCV has class_weight="balanced". Is it better to use it instead of
+        # balance_factor in construct_samples?
         clf = LogisticRegressionCV()
         clf.fit(features, labels)
 
@@ -151,18 +174,53 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
         :param utterances: List of utterances (texts).
         :param labels: Intent class labels corresponding to the utterances.
         """
+        if not self.train_classifier:
+            return  # do nothing if the classifier is not to be re-trained
+
         pairs, labels_ = construct_samples(utterances, labels, balancing_factor=1)
         self._fit(pairs, labels_)  # type: ignore[arg-type]
 
-    def predict(self, pairs: list[list[str]]) -> npt.NDArray[Any]:
+    def predict(self, pairs: list[tuple[str, str]]) -> npt.NDArray[Any]:
         """
         Predict probabilities of two utterances having the same intent label.
 
         :param pairs: List of text pairs to classify.
         :return: Numpy array of probabilities.
         """
-        features = self.get_features(pairs)
-        return self._clf.predict_proba(features)[:, 1]  # type: ignore[no-any-return]
+        if self.train_classifier and self._clf is None:
+            msg = "Classifier is not trained yet"
+            raise ValueError(msg)
+
+        features = self._get_features_or_predictions(pairs)
+
+        if self._clf is not None:
+            return np.array(self._clf.predict_proba(features)[:, 1])
+
+        return features
+
+    def rank(
+        self,
+        query: str,
+        query_docs: list[str],
+        top_k: int | None = None,
+    ) -> list[dict[str, Any]]:
+        """
+        Rank documents by meaning closeness to the query.
+
+        :param query: The reference document.
+        :param query_docs: List of documents to rank.
+        :param top_k: How many documents to return.
+        :return: List of dictionaries of ranked items.
+        """
+        query_doc_pairs = [(query, doc) for doc in query_docs]
+        scores = self.predict(query_doc_pairs)
+
+        if top_k is None:
+            top_k = len(query_docs)
+
+        results = [{"corpus_id": i, "score": scores[i]} for i in range(len(query_docs))]
+        results.sort(key=lambda x: x["score"], reverse=True)
+        return results[:top_k]
 
     def save(self, path: str) -> None:
         """
@@ -178,21 +236,13 @@ def save(self, path: str) -> None:
         clf_path = dump_dir / "classifier.joblib"
         joblib.dump(self._clf, clf_path)
 
-    def set_classifier(self, clf: LogisticRegressionCV) -> None:
-        """
-        Set the logistic regression classifier.
-
-        :param clf: LogisticRegressionCV instance.
-        """
-        self._clf = clf
-
     @classmethod
-    def load(cls, path: str) -> "CrossEncoderWithLogreg":
+    def load(cls, path: str) -> "NLITransformer":
         """
         Load the model and classifier from disk.
 
         :param path: Directory path containing the saved model and classifier.
-        :return: Initialized CrossEncoderWithLogreg instance.
+        :return: Initialized NLITransformer instance.
         """
         dump_dir = Path(path)
 
@@ -202,9 +252,5 @@ def load(cls, path: str) -> "CrossEncoderWithLogreg":
 
         # Load sentence transformer model
         crossencoder_dir = str(dump_dir / "crossencoder")
-        model = CrossEncoder(crossencoder_dir)
-
-        res = cls(model)
-        res.set_classifier(clf)
 
-        return res
+        return cls(crossencoder_dir, classifier_head=clf)
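
The mechanism behind head training is PyTorch's `register_forward_hook`: the hook fires on each forward pass of the classifier head and records the head's input, i.e. the pooled transformer features. Here is a self-contained sketch of the same pattern on a toy module; `TinyCrossEncoder` and its layers are illustrative stand-ins, not part of this commit:

    import torch
    from torch import nn

    class TinyCrossEncoder(nn.Module):
        """Stand-in for a transformer body followed by a classifier head."""

        def __init__(self) -> None:
            super().__init__()
            self.body = nn.Linear(8, 4)        # pretend transformer encoder
            self.classifier = nn.Linear(4, 1)  # head whose input we want to capture

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.classifier(torch.relu(self.body(x)))

    activations: list[torch.Tensor] = []

    def classifier_hook(_module: nn.Module, inputs: tuple[torch.Tensor, ...], _output: torch.Tensor) -> None:
        # inputs[0] is the feature tensor entering the classifier head
        activations.append(inputs[0].detach().cpu())

    model = TinyCrossEncoder()
    handle = model.classifier.register_forward_hook(classifier_hook)

    with torch.no_grad():
        model(torch.randn(3, 8))  # predictions are discarded; the hook keeps the features

    features = torch.cat(activations, dim=0)  # shape (3, 4), ready for LogisticRegressionCV
    activations.clear()
    handle.remove()  # detach the hook once features are collected

Note that, unlike the old `get_features`, the refactored class keeps the hook registered for the object's lifetime and clears the activation buffer after each pass.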

autointent/context/data_handler/_data_handler.py

Lines changed: 17 additions & 10 deletions
@@ -30,17 +30,16 @@ class DataHandler:
     """Data handler class."""
 
     def __init__(
-        self,
-        dataset: Dataset,
-        force_multilabel: bool = False,
-        random_seed: int = 0,
+        self, dataset: Dataset, force_multilabel: bool = False, random_seed: int = 0, split_train: bool = True
     ) -> None:
         """
         Initialize the data handler.
 
         :param dataset: Training dataset.
         :param force_multilabel: If True, force the dataset to be multilabel.
         :param random_seed: Seed for random number generation.
+        :param split_train: Whether to split the train set (defaults to splitting, to be used in
+            scoring and threshold search).
         """
         set_seed(random_seed)
 
@@ -50,7 +49,7 @@ def __init__(
 
         self.n_classes = self.dataset.n_classes
 
-        self._split(random_seed)
+        self._split(random_seed, split_train)
 
         self.regexp_patterns = [
             RegexPatterns(
@@ -191,11 +190,11 @@ def dump(self, filepath: str | Path) -> None:
         """
         self.dataset.to_json(filepath)
 
-    def _split(self, random_seed: int) -> None:
+    def _split(self, random_seed: int, split_train: bool) -> None:
         has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset)
         has_test_split = any(split.startswith(Split.TEST) for split in self.dataset)
 
-        if Split.TRAIN in self.dataset:
+        if split_train and Split.TRAIN in self.dataset:
             self._split_train(random_seed)
 
         if Split.TEST not in self.dataset:
@@ -252,13 +251,21 @@ def _split_validation_from_test(self, random_seed: int) -> None:
         )
 
     def _split_validation_from_train(self, random_seed: int) -> None:
-        for idx in range(2):
-            self.dataset[f"{Split.TRAIN}_{idx}"], self.dataset[f"{Split.VALIDATION}_{idx}"] = split_dataset(
+        if Split.TRAIN in self.dataset:
+            self.dataset[Split.TRAIN], self.dataset[Split.VALIDATION] = split_dataset(
                 self.dataset,
-                split=f"{Split.TRAIN}_{idx}",
+                split=Split.TRAIN,
                 test_size=0.2,
                 random_seed=random_seed,
             )
+        else:
+            for idx in range(2):
+                self.dataset[f"{Split.TRAIN}_{idx}"], self.dataset[f"{Split.VALIDATION}_{idx}"] = split_dataset(
+                    self.dataset,
+                    split=f"{Split.TRAIN}_{idx}",
+                    test_size=0.2,
+                    random_seed=random_seed,
+                )
 
     def _split_test(self, test_size: float, random_seed: int) -> None:
         self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TEST}_0"] = split_dataset(
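
A hedged sketch of what the new `split_train` flag enables for callers; the import paths and the `Dataset.from_json` loader are assumptions for illustration, not shown in this diff:

    from autointent import Dataset  # assumed import path
    from autointent.context.data_handler import DataHandler  # assumed import path

    dataset = Dataset.from_json("intents.json")  # assumed loader; any Dataset works

    # Default: the train split is further divided for scoring and threshold search.
    handler = DataHandler(dataset, random_seed=0)

    # Opt out of splitting the train set, e.g. when the pipeline manages splits itself.
    handler_plain = DataHandler(dataset, random_seed=0, split_train=False)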
