Skip to content

Commit 5c07205

Browse files
committed
fixing tests
1 parent 4c86e30 commit 5c07205

File tree

2 files changed

+12
-24
lines changed

2 files changed

+12
-24
lines changed

autointent/_transformers/_nli_transformer.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def __init__(
9797
train_classifier: bool = False,
9898
batch_size: int = 326,
9999
max_length: int | None = None,
100+
classifier_head: LogisticRegressionCV | None = None,
100101
) -> None:
101102
"""
102103
Initialize the NLITransformer.
@@ -106,14 +107,16 @@ def __init__(
106107
:param train_classifier: Whether to train a custom classifier, defaults to False.
107108
:param batch_size: Batch size for processing text pairs, defaults to 326.
108109
:param max_length (int, optional): Max length for input sequences for the cross encoder.
110+
:param classifier_head (LogisticRegressionCV, optional): Classifier (to be used in restore procedure mainly).
109111
"""
110112
self.cross_encoder = CrossEncoder(model, trust_remote_code=True, device=device, max_length=max_length) # type: ignore[arg-type]
111-
self.train_classifier = train_classifier
113+
self.train_classifier = False
112114
self.batch_size = batch_size
113115
self.max_length = max_length
114-
self._clf = None
116+
self._clf = classifier_head
115117

116-
if train_classifier:
118+
if classifier_head is not None or train_classifier:
119+
self.train_classifier = True
117120
self._logits_list: list[npt.NDArray[Any]] = []
118121
self._hook_handler = self.cross_encoder.model.classifier.register_forward_hook(self._classifier_hook)
119122

@@ -188,7 +191,7 @@ def predict(self, pairs: list[list[str]]) -> npt.NDArray[Any]:
188191
features = self.get_features(pairs)
189192

190193
if self._clf is not None:
191-
return self._clf.predict_proba(features)[:, 1]
194+
return np.array(self._clf.predict_proba(features)[:, 1])
192195

193196
return features
194197

@@ -230,17 +233,6 @@ def save(self, path: str) -> None:
230233
clf_path = dump_dir / "classifier.joblib"
231234
joblib.dump(self._clf, clf_path)
232235

233-
def set_classifier(self, clf: LogisticRegressionCV) -> None:
234-
"""
235-
Set the logistic regression classifier.
236-
237-
:param clf: LogisticRegressionCV instance.
238-
"""
239-
self._clf = clf
240-
241-
if clf is None:
242-
self.train_classifier = False
243-
244236
@classmethod
245237
def load(cls, path: str) -> "NLITransformer":
246238
"""
@@ -257,9 +249,5 @@ def load(cls, path: str) -> "NLITransformer":
257249

258250
# Load sentence transformer model
259251
crossencoder_dir = str(dump_dir / "crossencoder")
260-
model = CrossEncoder(crossencoder_dir)
261-
262-
res = cls(model)
263-
res.set_classifier(clf)
264252

265-
return res
253+
return cls(crossencoder_dir, classifier_head=clf)

tests/_transformers/test_nli_transformer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def data_handler():
1414

1515

1616
def test_nli_transformer_predict_without_trained_head(data_handler):
17-
model = NLITransformer(model="cross-encoder/ms-marco-MiniLM-L-6-v2", device="cuda", train_classifier=True)
17+
model = NLITransformer(model="cross-encoder/ms-marco-MiniLM-L-6-v2", device="cpu", train_classifier=True)
1818
with pytest.raises(ValueError, match="Classifier is not trained yet"):
1919
model.predict(data_handler.train_utterances())
2020

@@ -49,7 +49,7 @@ def check_ranking(ranked, labels):
4949

5050

5151
def test_nli_transformer_predict_with_train_head(data_handler):
52-
model = NLITransformer(model="cross-encoder/ms-marco-MiniLM-L-6-v2", device="cuda", train_classifier=True)
52+
model = NLITransformer(model="cross-encoder/ms-marco-MiniLM-L-6-v2", device="cpu", train_classifier=True)
5353
texts = data_handler.train_utterances()
5454
labels = data_handler.train_labels()
5555
model.fit(texts, labels)
@@ -61,7 +61,7 @@ def test_nli_transformer_predict_with_train_head(data_handler):
6161

6262

6363
def test_nli_transformer_predict_default(data_handler):
64-
model = NLITransformer(model="cross-encoder/ms-marco-MiniLM-L-6-v2", device="cuda")
64+
model = NLITransformer(model="cross-encoder/ms-marco-MiniLM-L-6-v2", device="cpu")
6565
texts = data_handler.train_utterances()
6666
labels = data_handler.train_labels()
6767
predicted = model.predict(build_pairs(texts))
@@ -72,7 +72,7 @@ def test_nli_transformer_predict_default(data_handler):
7272

7373

7474
def test_nli_transformer_predict_default_with_fit(data_handler):
75-
model = NLITransformer(model="cross-encoder/ms-marco-MiniLM-L-6-v2", device="cuda")
75+
model = NLITransformer(model="cross-encoder/ms-marco-MiniLM-L-6-v2", device="cpu")
7676
texts = data_handler.train_utterances()
7777
labels = data_handler.train_labels()
7878
model.fit(texts, labels)

0 commit comments

Comments
 (0)