Skip to content

Commit 128e29c

Browse files
committed
fix: fixed _classifier
1 parent f2536de commit 128e29c

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

autointent/modules/embedding/_logreg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,14 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
113113
"""
114114
self._multilabel = isinstance(labels[0], list)
115115

116-
self.embedder = Embedder(
116+
self._embedder = Embedder(
117117
device=self.embedder_device,
118118
model_name_or_path=self.embedder_name,
119119
batch_size=self.embedder_batch_size,
120120
max_length=self.embedder_max_length,
121121
use_cache=self.embedder_use_cache,
122122
)
123-
embeddings = self.embedder.embed(utterances)
123+
embeddings = self._embedder.embed(utterances)
124124

125125
if self._multilabel:
126126
self._label_encoder = None
@@ -155,7 +155,7 @@ def score(
155155
message = f"Invalid split '{split}' provided. Expected one of 'validation', or 'test'."
156156
raise ValueError(message)
157157

158-
embeddings = self.embedder.embed(utterances)
158+
embeddings = self._embedder.embed(utterances)
159159
probas = self._classifier.predict_proba(embeddings)
160160
if self._multilabel:
161161
probas = np.stack(probas, axis=1)[..., 1]

tests/modules/retrieval/test_logreg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def test_fit_trains_model():
1616
labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]
1717
module.fit(utterances, labels)
1818

19-
assert module.classifier.coef_ is not None
20-
assert len(module.classifier.coef_) > 0
19+
assert module._classifier.coef_ is not None
20+
assert len(module._classifier.coef_) > 0
2121
assert module.label_encoder.classes_.tolist() == [0, 1]
2222

2323

0 commit comments

Comments
 (0)