@@ -97,6 +97,7 @@ def __init__(
9797 train_classifier : bool = False ,
9898 batch_size : int = 326 ,
9999 max_length : int | None = None ,
100+ classifier_head : LogisticRegressionCV | None = None ,
100101 ) -> None :
101102 """
102103 Initialize the NLITransformer.
@@ -106,14 +107,16 @@ def __init__(
106107 :param train_classifier: Whether to train a custom classifier, defaults to False.
107108 :param batch_size: Batch size for processing text pairs, defaults to 326.
108109 :param max_length (int, optional): Max length for input sequences for the cross encoder.
110+ :param classifier_head (LogisticRegressionCV, optional): Classifier head to attach to the transformer; mainly passed by `load` when restoring a saved instance.
109111 """
110112 self .cross_encoder = CrossEncoder (model , trust_remote_code = True , device = device , max_length = max_length ) # type: ignore[arg-type]
111- self .train_classifier = train_classifier
113+ self .train_classifier = False
112114 self .batch_size = batch_size
113115 self .max_length = max_length
114- self ._clf = None
116+ self ._clf = classifier_head
115117
116- if train_classifier :
118+ if classifier_head is not None or train_classifier :
119+ self .train_classifier = True
117120 self ._logits_list : list [npt .NDArray [Any ]] = []
118121 self ._hook_handler = self .cross_encoder .model .classifier .register_forward_hook (self ._classifier_hook )
119122
@@ -188,7 +191,7 @@ def predict(self, pairs: list[list[str]]) -> npt.NDArray[Any]:
188191 features = self .get_features (pairs )
189192
190193 if self ._clf is not None :
191- return self ._clf .predict_proba (features )[:, 1 ]
194+ return np . array ( self ._clf .predict_proba (features )[:, 1 ])
192195
193196 return features
194197
@@ -230,17 +233,6 @@ def save(self, path: str) -> None:
230233 clf_path = dump_dir / "classifier.joblib"
231234 joblib .dump (self ._clf , clf_path )
232235
233- def set_classifier (self , clf : LogisticRegressionCV ) -> None :
234- """
235- Set the logistic regression classifier.
236-
237- :param clf: LogisticRegressionCV instance.
238- """
239- self ._clf = clf
240-
241- if clf is None :
242- self .train_classifier = False
243-
244236 @classmethod
245237 def load (cls , path : str ) -> "NLITransformer" :
246238 """
@@ -257,9 +249,5 @@ def load(cls, path: str) -> "NLITransformer":
257249
258250 # Load sentence transformer model
259251 crossencoder_dir = str (dump_dir / "crossencoder" )
260- model = CrossEncoder (crossencoder_dir )
261-
262- res = cls (model )
263- res .set_classifier (clf )
264252
265- return res
253+ return cls ( crossencoder_dir , classifier_head = clf )
0 commit comments