1- """CrossEncoderWithLogreg class for cross-encoder-based binary classification with logistic regression."""
1+ """NLITransformer class for cross-encoder-based estimation of meaning closeness.
2+
3+ Can be used to rank retrieved sentences by meaning closeness to provided utterance.
4+ """
25
36import itertools as it
47import logging
58from pathlib import Path
69from random import shuffle
7- from typing import Any , TypeVar
10+ from typing import Any
811
912import joblib
1013import numpy as np
1114import numpy .typing as npt
1215import torch
1316from sentence_transformers import CrossEncoder
1417from sklearn .linear_model import LogisticRegressionCV
18+ from torch import nn
1519
1620from autointent .custom_types import LabelType
1721
@@ -54,15 +58,13 @@ def construct_samples(
5458 return pairs , labels
5559
5660
57- CrossEncoderType = TypeVar ("CrossEncoderType" , bound = "CrossEncoderWithLogreg" )
58-
59-
60- class CrossEncoderWithLogreg :
61+ class NLITransformer :
6162 r"""
62- Cross-encoder with logistic regression for binary classification .
63+ Cross-encoder for NLI .
6364
64- This class uses a SentenceTransformers CrossEncoder model to extract features
65- and LogisticRegressionCV for classification.
65+ In the hart this class uses a SentenceTransformers CrossEncoder model to extract features.
66+ Then it uses either the model's clissifier or our custom trained LogisticRegressionCV
67+ (custom classifier layer in the future) to rank documents using similarity score to the query.
6668
6769 :ivar cross_encoder: The CrossEncoder model used to extract features.
6870 :ivar batch_size: Batch size for processing text pairs.
@@ -72,10 +74,8 @@ class CrossEncoderWithLogreg:
7274 Examples
7375 --------
7476 Creating and fitting the CrossEncoderWithLogreg:
75- >>> from autointent.modules import CrossEncoderWithLogreg
76- >>> from sentence_transformers import CrossEncoder
77- >>> model = CrossEncoder("cross-encoder-model")
78- >>> scorer = CrossEncoderWithLogreg(model)
77+ >>> from autointent._transformers import NLITransformer
78+ >>> scorer = NLITransformer("cross-encoder-model")
7979 >>> utterances = ["What is your name?", "How old are you?"]
8080 >>> labels = [1, 0]
8181 >>> scorer.fit(utterances, labels)
@@ -87,43 +87,64 @@ class CrossEncoderWithLogreg:
8787
8888 Saving and loading the model:
8989 >>> scorer.save("outputs/")
90- >>> loaded_scorer = CrossEncoderWithLogreg .load("outputs/")
90+ >>> loaded_scorer = NLITransformer .load("outputs/")
9191 """
9292
93- def __init__ (self , model : CrossEncoder , batch_size : int = 326 ) -> None :
93+ def __init__ (
94+ self ,
95+ model : str ,
96+ device : str = "cpu" ,
97+ train_classifier : bool = False ,
98+ batch_size : int = 326 ,
99+ max_length : int | None = None ,
100+ classifier_head : LogisticRegressionCV | None = None ,
101+ ) -> None :
94102 """
95- Initialize the CrossEncoderWithLogreg .
103+ Initialize the NLITransformer .
96104
97- :param model: The CrossEncoder model to use.
105+ :param model: The CrossEncoder model name to use.
106+ :param device: Device to run operations on, e.g., "cpu" or "cuda".
107+ :param train_classifier: Whether to train a custom classifier, defaults to False.
98108 :param batch_size: Batch size for processing text pairs, defaults to 326.
109+ :param max_length (int, optional): Max length for input sequences for the cross encoder.
110+ :param classifier_head (LogisticRegressionCV, optional): Classifier (to be used in restore procedure mainly).
99111 """
100- self .cross_encoder = model
112+ self .cross_encoder = CrossEncoder (model , trust_remote_code = True , device = device , max_length = max_length ) # type: ignore[arg-type]
113+ self .train_classifier = False
101114 self .batch_size = batch_size
115+ self .max_length = max_length
116+ self ._clf = classifier_head
117+
118+ if classifier_head is not None or train_classifier :
119+ self .train_classifier = True
120+ self ._activations_list : list [npt .NDArray [Any ]] = []
121+ self ._hook_handler = self .cross_encoder .model .classifier .register_forward_hook (self ._classifier_hook )
122+
123+ def _classifier_hook (self , _module , input_tensor , _output_tensor ) -> None : # type: ignore[no-untyped-def] # noqa: ANN001
124+ self ._activations_list .append (input_tensor [0 ].cpu ().numpy ())
102125
103126 @torch .no_grad ()
104- def get_features (self , pairs : list [list [ str ]]) -> npt .NDArray [Any ]:
127+ def _get_features_or_predictions (self , pairs : list [tuple [ str , str ]]) -> npt .NDArray [Any ]:
105128 """
106- Extract features from text pairs using the CrossEncoder model.
129+ Extract features or get predictions using the CrossEncoder model.
130+
131+ If :py:attr:`~train_classifier` is ``True``, return raw activations from
132+ cross-encoder transformer. Otherwise, get predictions from cross-encoder head.
107133
108134 :param pairs: List of text pairs.
109135 :return: Numpy array of extracted features.
110136 """
111- logits_list : list [npt .NDArray [Any ]] = []
112-
113- def hook_function (module , input_tensor , output_tensor ) -> None : # type: ignore[no-untyped-def] # noqa: ARG001, ANN001
114- logits_list .append (input_tensor [0 ].cpu ().numpy ())
137+ if not self .train_classifier :
138+ return np .array (self .cross_encoder .predict (pairs , batch_size = self .batch_size , activation_fct = nn .Sigmoid ()))
115139
116- handler = self .cross_encoder .model .classifier .register_forward_hook (hook_function )
140+ # put the data through, features will be taken in the hook
141+ self .cross_encoder .predict (pairs , batch_size = self .batch_size )
117142
118- for i in range ( 0 , len ( pairs ), self . batch_size ):
119- batch = pairs [ i : i + self .batch_size ]
120- self . cross_encoder . predict ( batch )
143+ res = np . concatenate ( self . _activations_list , axis = 0 )
144+ self ._activations_list . clear ()
145+ return res # type: ignore[no-any-return]
121146
122- handler .remove ()
123-
124- return np .concatenate (logits_list , axis = 0 )
125-
126- def _fit (self , pairs : list [list [str ]], labels : list [LabelType ]) -> None :
147+ def _fit (self , pairs : list [tuple [str , str ]], labels : list [LabelType ]) -> None :
127148 """
128149 Train the logistic regression model on cross-encoder features.
129150
@@ -137,8 +158,10 @@ def _fit(self, pairs: list[list[str]], labels: list[LabelType]) -> None:
137158 logger .error (msg )
138159 raise ValueError (msg )
139160
140- features = self .get_features (pairs )
161+ features = self ._get_features_or_predictions (pairs )
141162
163+ # TODO: LogisticRegressionCV has class_weight="balanced". Is it better to use it instead of balance_factor in
164+ # construct_samples?
142165 clf = LogisticRegressionCV ()
143166 clf .fit (features , labels )
144167
@@ -151,18 +174,53 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
151174 :param utterances: List of utterances (texts).
152175 :param labels: Intent class labels corresponding to the utterances.
153176 """
177+ if not self .train_classifier :
178+ return # do nothing if the classifier is not to be re-trained
179+
154180 pairs , labels_ = construct_samples (utterances , labels , balancing_factor = 1 )
155181 self ._fit (pairs , labels_ ) # type: ignore[arg-type]
156182
157- def predict (self , pairs : list [list [ str ]]) -> npt .NDArray [Any ]:
183+ def predict (self , pairs : list [tuple [ str , str ]]) -> npt .NDArray [Any ]:
158184 """
159185 Predict probabilities of two utterances having the same intent label.
160186
161187 :param pairs: List of text pairs to classify.
162188 :return: Numpy array of probabilities.
163189 """
164- features = self .get_features (pairs )
165- return self ._clf .predict_proba (features )[:, 1 ] # type: ignore[no-any-return]
190+ if self .train_classifier and self ._clf is None :
191+ msg = "Classifier is not trained yet"
192+ raise ValueError (msg )
193+
194+ features = self ._get_features_or_predictions (pairs )
195+
196+ if self ._clf is not None :
197+ return np .array (self ._clf .predict_proba (features )[:, 1 ])
198+
199+ return features
200+
201+ def rank (
202+ self ,
203+ query : str ,
204+ query_docs : list [str ],
205+ top_k : int | None = None ,
206+ ) -> list [dict [str , Any ]]:
207+ """
208+ Rank documents according to meaning closeness to the query.
209+
210+ :param query: The reference document.
211+ :query_docs: List of documents to rank
212+ :top_k: how many document to return
213+ :return: array of dictionaries of ranked items.
214+ """
215+ query_doc_pairs = [(query , doc ) for doc in query_docs ]
216+ scores = self .predict (query_doc_pairs )
217+
218+ if top_k is None :
219+ top_k = len (query_docs )
220+
221+ results = [{"corpus_id" : i , "score" : scores [i ]} for i in range (len (query_docs ))]
222+ results .sort (key = lambda x : x ["score" ], reverse = True )
223+ return results [:top_k ]
166224
167225 def save (self , path : str ) -> None :
168226 """
@@ -178,21 +236,13 @@ def save(self, path: str) -> None:
178236 clf_path = dump_dir / "classifier.joblib"
179237 joblib .dump (self ._clf , clf_path )
180238
181- def set_classifier (self , clf : LogisticRegressionCV ) -> None :
182- """
183- Set the logistic regression classifier.
184-
185- :param clf: LogisticRegressionCV instance.
186- """
187- self ._clf = clf
188-
189239 @classmethod
190- def load (cls , path : str ) -> "CrossEncoderWithLogreg " :
240+ def load (cls , path : str ) -> "NLITransformer " :
191241 """
192242 Load the model and classifier from disk.
193243
194244 :param path: Directory path containing the saved model and classifier.
195- :return: Initialized CrossEncoderWithLogreg instance.
245+ :return: Initialized NLITransformer instance.
196246 """
197247 dump_dir = Path (path )
198248
@@ -202,9 +252,5 @@ def load(cls, path: str) -> "CrossEncoderWithLogreg":
202252
203253 # Load sentence transformer model
204254 crossencoder_dir = str (dump_dir / "crossencoder" )
205- model = CrossEncoder (crossencoder_dir )
206-
207- res = cls (model )
208- res .set_classifier (clf )
209255
210- return res
256+ return cls ( crossencoder_dir , classifier_head = clf )
0 commit comments