
Commit fe9a587

fix: multilabel and fix scorer metric
1 parent b9ff7c7 commit fe9a587

File tree

1 file changed: +20 -5 lines changed


autointent/modules/embedding/_retrieval.py

Lines changed: 20 additions & 5 deletions
@@ -4,8 +4,9 @@
 from pathlib import Path
 from typing import Literal
 
-from sklearn.linear_model import LogisticRegressionCV
-from sklearn.preprocessing import LabelEncoder
+from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
+from sklearn.multioutput import MultiOutputClassifier
+from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer
 
 from autointent import Context, Embedder
 from autointent.context.optimization_info import RetrieverArtifact
@@ -103,8 +104,7 @@ def __init__(
         self.batch_size = batch_size
         self.max_length = max_length
         self.embedder_use_cache = embedder_use_cache
-        self.classifier = LogisticRegressionCV(cv=cv)
-        self.label_encoder = LabelEncoder()
+        self.cv = cv
 
         super().__init__(k=k)
 
@@ -152,6 +152,8 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
         :param utterances: List of text data to index.
         :param labels: List of corresponding labels for the utterances.
         """
+        self._multilabel = isinstance(labels[0], list)
+
         self.embedder = Embedder(
             device=self.embedder_device,
             model_name=self.embedder_name,
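
The new self._multilabel flag keys off the type of the first label: list-valued labels mark the dataset as multilabel, scalar labels as single-label. A minimal illustration of that check (the label values below are invented):

# Multilabel data: each utterance may carry several intent ids.
isinstance([[0, 2], [1], [0, 1, 2]][0], list)  # True  -> multilabel branch
# Single-label data: exactly one intent id per utterance.
isinstance([0, 1, 2][0], list)                 # False -> single-label branch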
@@ -160,6 +162,16 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
             use_cache=self.embedder_use_cache,
         )
         embeddings = self.embedder.embed(utterances)
+        if self._multilabel:
+            self.label_encoder = MultiLabelBinarizer()
+            encoded_labels = self.label_encoder.fit_transform(labels)
+            base_clf = LogisticRegression()
+            self.classifier = MultiOutputClassifier(base_clf)
+        else:
+            self.label_encoder = LabelEncoder()
+            encoded_labels = self.label_encoder.fit_transform(labels)
+            self.classifier = LogisticRegressionCV(cv=self.cv)
+
         self.label_encoder.fit(labels)
         encoded_labels = self.label_encoder.transform(labels)
         self.classifier.fit(embeddings, encoded_labels)
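
In the multilabel branch, MultiLabelBinarizer turns each list of intent ids into a row of binary indicators, and MultiOutputClassifier fits one LogisticRegression per indicator column, mirroring the single-label LabelEncoder + LogisticRegressionCV path. A minimal sketch of that combination on toy data (the embeddings and labels here are invented for illustration, not taken from autointent):

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from sklearn.preprocessing import MultiLabelBinarizer

# Toy 2-dimensional "embeddings" and list-valued labels.
embeddings = np.array([[0.1, 0.9], [0.8, 0.2], [0.5, 0.5], [0.9, 0.1]])
labels = [[0, 2], [1], [0, 1], [1, 2]]

encoder = MultiLabelBinarizer()
encoded_labels = encoder.fit_transform(labels)        # shape (4, 3): one column per intent id
classifier = MultiOutputClassifier(LogisticRegression())
classifier.fit(embeddings, encoded_labels)            # one binary classifier per column

predicted = classifier.predict(embeddings)            # binary indicator matrix
decoded = encoder.inverse_transform(predicted)        # back to tuples of intent ids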
@@ -192,7 +204,7 @@ def score(
         predicted_encoded = self.classifier.predict(embeddings)
         predicted_labels = self.label_encoder.inverse_transform(predicted_encoded)
 
-        return metric_fn(labels, [predicted_labels])
+        return metric_fn(labels, predicted_labels.reshape(-1, 1))
 
     def get_assets(self) -> RetrieverArtifact:
         """
@@ -262,6 +274,9 @@ def load(self, path: str) -> None:
         self.label_encoder = LabelEncoder()
         self.label_encoder.classes_ = self.classifier_metadata["classes"]
 
+    def predict(self, utterances: list[str]) -> tuple[list[list[int | list[int]]], list[list[float]], list[list[str]]]:
+        pass
+
 
 class RetrievalEmbedding(EmbeddingModule):
     r"""

0 commit comments
