Skip to content

Commit cd75d72

Browse files
committed
feat: updated predict() in logreg
1 parent 8eaba9c commit cd75d72

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

autointent/modules/embedding/_logreg.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Literal
44

55
import numpy as np
6+
from numpy.typing import NDArray
67
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
78
from sklearn.multioutput import MultiOutputClassifier
89
from sklearn.preprocessing import LabelEncoder
@@ -155,10 +156,7 @@ def score(
155156
message = f"Invalid split '{split}' provided. Expected one of 'validation', or 'test'."
156157
raise ValueError(message)
157158

158-
embeddings = self._embedder.embed(utterances)
159-
probas = self._classifier.predict_proba(embeddings)
160-
if self._multilabel:
161-
probas = np.stack(probas, axis=1)[..., 1]
159+
probas = self.predict(utterances)
162160
metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS
163161
return self.score_metrics((labels, probas), metrics_dict)
164162

@@ -170,5 +168,11 @@ def get_assets(self) -> RetrieverArtifact:
170168
"""
171169
return RetrieverArtifact(embedder_name=self.embedder_name)
172170

173-
def predict(self, utterances: list[str]) -> None:
174-
pass
171+
def predict(self, utterances: list[str]) -> NDArray[np.float64] | list[NDArray[np.float64]]:
172+
embeddings = self._embedder.embed(utterances)
173+
probas = self._classifier.predict_proba(embeddings)
174+
175+
if self._multilabel:
176+
probas = np.stack(probas, axis=1)[..., 1]
177+
178+
return probas

0 commit comments

Comments
 (0)