1919from torch import nn
2020
2121from autointent .custom_types import ListOfLabels
22+ from autointent .schemas ._schemas import CrossEncoderConfig
2223
2324logger = logging .getLogger (__name__ )
2425
2526
2627class CrossEncoderMetadata (TypedDict ):
2728 model_name : str
2829 train_classifier : bool
29- device : str
30+ device : str | None
3031 max_length : int | None
3132 batch_size : int
3233
@@ -104,32 +105,27 @@ class Ranker:
104105
105106 def __init__ (
106107 self ,
107- model_name : str ,
108- device : str = "cpu" ,
109- train_classifier : bool = False ,
110- batch_size : int = 326 ,
111- max_length : int | None = None ,
108+ cross_encoder_config : CrossEncoderConfig ,
112109 classifier_head : LogisticRegressionCV | None = None ,
113110 ) -> None :
114111 """
115112 Initialize the Ranker.
116113
117- :param model: The cross-encoder hugging face model name to use.
118- :param device: Device to run operations on, e.g., "cpu" or "cuda".
119- :param train_classifier: Whether to train a custom classifier, defaults to False.
120- :param batch_size: Batch size for processing text pairs, defaults to 326.
114+ :param cross_encoder_config: Cross-encoder configuration (model name, device, max length, batch size, and whether to train a classifier head).
121115 :param max_length (int, optional): Max length for input sequences for the cross encoder.
122116 :param classifier_head (LogisticRegressionCV, optional): Classifier (to be used in restore procedure mainly).
123117 """
124- self .model_name = model_name
125- self .device = device
126- self .cross_encoder = st .CrossEncoder (model_name , trust_remote_code = True , device = device , max_length = max_length ) # type: ignore[arg-type]
118+ self .cross_encoder = st .CrossEncoder (
119+ cross_encoder_config .model_name ,
120+ trust_remote_code = True ,
121+ device = cross_encoder_config .device , # type: ignore[arg-type]
122+ max_length = cross_encoder_config .max_length , # type: ignore[arg-type]
123+ )
127124 self .train_classifier = False
128- self .batch_size = batch_size
129- self .max_length = max_length
130125 self ._clf = classifier_head
126+ self .cross_encoder_config = cross_encoder_config
131127
132- if classifier_head is not None or train_classifier :
128+ if classifier_head is not None or cross_encoder_config . train_head :
133129 self .train_classifier = True
134130 self ._activations_list : list [npt .NDArray [Any ]] = []
135131 self ._hook_handler = self .cross_encoder .model .classifier .register_forward_hook (self ._classifier_hook )
@@ -149,10 +145,14 @@ def _get_features_or_predictions(self, pairs: list[tuple[str, str]]) -> npt.NDAr
149145 :return: Numpy array of extracted features.
150146 """
151147 if not self .train_classifier :
152- return np .array (self .cross_encoder .predict (pairs , batch_size = self .batch_size , activation_fct = nn .Sigmoid ()))
148+ return np .array (
149+ self .cross_encoder .predict (
150+ pairs , batch_size = self .cross_encoder_config .batch_size , activation_fct = nn .Sigmoid () # type: ignore[arg-type]
151+ )
152+ )
153153
154154 # put the data through, features will be taken in the hook
155- self .cross_encoder .predict (pairs , batch_size = self .batch_size )
155+ self .cross_encoder .predict (pairs , batch_size = self .cross_encoder_config . batch_size ) # type: ignore[arg-type]
156156
157157 res = np .concatenate (self ._activations_list , axis = 0 )
158158 self ._activations_list .clear ()
@@ -222,8 +222,8 @@ def rank(
222222 Rank documents according to meaning closeness to the query.
223223
224224 :param query: The reference document.
225- :query_docs: List of documents to rank
226- :top_k: how many document to return
225+ :param query_docs: List of documents to rank
226+ :param top_k: how many documents to return
227227 :return: array of dictionaries of ranked items.
228228 """
229229 query_doc_pairs = [(query , doc ) for doc in query_docs ]
@@ -246,11 +246,11 @@ def save(self, path: str) -> None:
246246 dump_dir .mkdir (parents = True )
247247
248248 metadata = CrossEncoderMetadata (
249- model_name = self .model_name ,
249+ model_name = self .cross_encoder_config . model_name ,
250250 train_classifier = self .train_classifier ,
251- device = self .device ,
252- max_length = self .max_length ,
253- batch_size = self .batch_size ,
251+ device = self .cross_encoder_config . device ,
252+ max_length = self .cross_encoder_config . max_length ,
253+ batch_size = self .cross_encoder_config . batch_size ,
254254 )
255255
256256 with (dump_dir / self .metadata_file_name ).open ("w" ) as file :
@@ -271,4 +271,13 @@ def load(cls, path: Path) -> "Ranker":
271271 with (path / cls .metadata_file_name ).open () as file :
272272 metadata : CrossEncoderMetadata = json .load (file )
273273
274- return cls (** metadata , classifier_head = clf )
274+ return cls (
275+ CrossEncoderConfig (
276+ model_name = metadata ["model_name" ],
277+ device = metadata ["device" ],
278+ max_length = metadata ["max_length" ],
279+ batch_size = metadata ["batch_size" ],
280+ train_head = metadata ["train_classifier" ],
281+ ),
282+ classifier_head = clf ,
283+ )
0 commit comments